From 8f8cbd22d352fc59dfd5bf19de979b05bb5c7938 Mon Sep 17 00:00:00 2001 From: j4ckl1u Date: Wed, 14 Jul 2021 12:51:22 +0800 Subject: [PATCH] add code --- LICENSE | 74 ++ README.md | 58 + config/config.yaml | 111 ++ config/criterion/adaptive_loss.yaml | 3 + config/criterion/cross_entropy.yaml | 2 + config/lr_scheduler/cosine.yaml | 7 + config/lr_scheduler/inverse_sqrt.yaml | 3 + config/model/transformer_lm.yaml | 36 + config/model/transformer_lm_baevski_gbw.yaml | 36 + .../model/transformer_lm_baevski_wiki103.yaml | 36 + config/model/transformer_lm_big.yaml | 36 + config/model/transformer_lm_gbw.yaml | 36 + config/model/transformer_lm_gpt.yaml | 36 + config/model/transformer_lm_gpt2_big.yaml | 36 + config/model/transformer_lm_gpt2_medium.yaml | 36 + config/model/transformer_lm_gpt2_small.yaml | 36 + config/model/transformer_lm_wiki103.yaml | 36 + config/optimizer/adam.yaml | 5 + config/optimizer/nag.yaml | 3 + config/task/language_modeling.yaml | 10 + examples/.gitignore | 2 + examples/__init__.py | 6 + examples/speech_recognition/README.md | 106 ++ examples/speech_recognition/__init__.py | 1 + .../speech_recognition/criterions/ASG_loss.py | 170 +++ .../speech_recognition/criterions/__init__.py | 17 + .../criterions/cross_entropy_acc.py | 130 ++ examples/speech_recognition/data/__init__.py | 11 + .../speech_recognition/data/asr_dataset.py | 122 ++ examples/speech_recognition/data/collaters.py | 131 ++ .../speech_recognition/data/data_utils.py | 100 ++ examples/speech_recognition/data/replabels.py | 70 + .../datasets/asr_prep_json.py | 125 ++ .../datasets/prepare-librispeech.sh | 88 ++ examples/speech_recognition/infer.py | 474 +++++++ .../speech_recognition/models/__init__.py | 8 + .../models/vggtransformer.py | 1019 +++++++++++++++ .../models/w2l_conv_glu_enc.py | 177 +++ examples/speech_recognition/tasks/__init__.py | 8 + .../tasks/speech_recognition.py | 157 +++ .../speech_recognition/utils/wer_utils.py | 381 ++++++ examples/speech_recognition/w2l_decoder.py | 453 +++++++ examples/unispeech/adjust_sample_rate.py | 60 + examples/unispeech/libri_labels.py | 56 + .../unispeech/scripts/continue_pretrain.sh | 11 + examples/unispeech/scripts/finetune.sh | 11 + .../scripts/multilingal_large_pretrain.sh | 13 + .../scripts/one2one_large_pretrain_en1350.sh | 12 + examples/unispeech/unispeech_manifest.py | 46 + examples/unispeech/wav2vec_manifest.py | 76 ++ fairseq/__init__.py | 41 + fairseq/benchmark/__init__.py | 7 + fairseq/benchmark/dummy_lm.py | 125 ++ fairseq/benchmark/dummy_masked_lm.py | 127 ++ fairseq/benchmark/dummy_model.py | 96 ++ fairseq/benchmark/dummy_mt.py | 119 ++ fairseq/binarizer.py | 105 ++ fairseq/checkpoint_utils.py | 610 +++++++++ fairseq/clib/libbleu/libbleu.cpp | 141 ++ fairseq/clib/libbleu/module.cpp | 37 + fairseq/clib/libnat/edit_dist.cpp | 231 ++++ fairseq/clib/libnat_cuda/binding.cpp | 60 + fairseq/clib/libnat_cuda/edit_dist.cu | 332 +++++ fairseq/clib/libnat_cuda/edit_dist.h | 25 + fairseq/criterions/__init__.py | 36 + fairseq/criterions/ctc.py | 263 ++++ fairseq/criterions/fairseq_criterion.py | 118 ++ .../label_smoothed_cross_entropy.py | 160 +++ ...l_smoothed_cross_entropy_with_alignment.py | 125 ++ fairseq/criterions/wav2vec_criterion.py | 192 +++ fairseq/criterions/wav2vec_mtl.py | 137 ++ fairseq/data/__init__.py | 39 + fairseq/data/add_target_dataset.py | 70 + fairseq/data/audio/__init__.py | 0 fairseq/data/audio/audio_utils.py | 85 ++ .../data/audio/feature_transforms/__init__.py | 82 ++ .../audio/feature_transforms/global_cmvn.py | 25 + .../audio/feature_transforms/specaugment.py | 131 ++ .../feature_transforms/utterance_cmvn.py | 40 + fairseq/data/audio/raw_audio_dataset.py | 192 +++ fairseq/data/audio/speech_to_text_dataset.py | 528 ++++++++ fairseq/data/base_wrapper_dataset.py | 78 ++ fairseq/data/concat_dataset.py | 124 ++ fairseq/data/concat_sentences_dataset.py | 54 + fairseq/data/data_utils.py | 499 +++++++ fairseq/data/data_utils_fast.pyx | 123 ++ fairseq/data/dictionary.py | 418 ++++++ fairseq/data/encoders/__init__.py | 29 + fairseq/data/encoders/byte_bpe.py | 48 + fairseq/data/encoders/byte_utils.py | 51 + fairseq/data/encoders/bytes.py | 34 + fairseq/data/encoders/characters.py | 30 + fairseq/data/encoders/fastbpe.py | 36 + fairseq/data/encoders/gpt2_bpe.py | 45 + fairseq/data/encoders/gpt2_bpe_utils.py | 140 ++ fairseq/data/encoders/hf_bert_bpe.py | 50 + fairseq/data/encoders/hf_byte_bpe.py | 46 + fairseq/data/encoders/moses_tokenizer.py | 49 + fairseq/data/encoders/nltk_tokenizer.py | 23 + fairseq/data/encoders/sentencepiece_bpe.py | 48 + fairseq/data/encoders/space_tokenizer.py | 20 + fairseq/data/encoders/subword_nmt_bpe.py | 54 + fairseq/data/encoders/utils.py | 30 + fairseq/data/fairseq_dataset.py | 191 +++ fairseq/data/fasta_dataset.py | 107 ++ fairseq/data/id_dataset.py | 19 + fairseq/data/indexed_dataset.py | 561 ++++++++ fairseq/data/iterators.py | 586 +++++++++ fairseq/data/legacy/__init__.py | 16 + fairseq/data/legacy/block_pair_dataset.py | 311 +++++ fairseq/data/legacy/masked_lm_dataset.py | 303 +++++ fairseq/data/legacy/masked_lm_dictionary.py | 60 + fairseq/data/plasma_utils.py | 91 ++ fairseq/data/resampling_dataset.py | 139 ++ fairseq/data/token_block_dataset.py | 168 +++ fairseq/data/token_block_utils_fast.pyx | 187 +++ fairseq/dataclass/__init__.py | 10 + fairseq/dataclass/configs.py | 876 +++++++++++++ fairseq/dataclass/constants.py | 35 + fairseq/dataclass/initialize.py | 48 + fairseq/dataclass/utils.py | 363 ++++++ fairseq/distributed_utils.py | 458 +++++++ fairseq/file_io.py | 116 ++ fairseq/file_utils.py | 353 +++++ fairseq/hub_utils.py | 296 +++++ fairseq/incremental_decoding_utils.py | 51 + fairseq/iterative_refinement_generator.py | 359 +++++ fairseq/legacy_distributed_data_parallel.py | 171 +++ fairseq/logging/__init__.py | 0 fairseq/logging/meters.py | 291 +++++ fairseq/logging/metrics.py | 288 +++++ fairseq/logging/progress_bar.py | 355 +++++ fairseq/model_parallel/__init__.py | 6 + fairseq/model_parallel/criterions/__init__.py | 14 + .../vocab_parallel_cross_entropy.py | 87 ++ fairseq/model_parallel/megatron_trainer.py | 67 + fairseq/model_parallel/models/__init__.py | 20 + .../pipeline_parallel_transformer/__init__.py | 6 + .../pipeline_parallel_transformer/layers.py | 600 +++++++++ .../pipeline_parallel_transformer/model.py | 711 ++++++++++ fairseq/model_parallel/models/transformer.py | 116 ++ .../model_parallel/models/transformer_lm.py | 174 +++ fairseq/model_parallel/modules/__init__.py | 23 + .../modules/multihead_attention.py | 352 +++++ .../modules/transformer_layer.py | 78 ++ .../modules/transformer_sentence_encoder.py | 59 + .../transformer_sentence_encoder_layer.py | 77 ++ fairseq/models/__init__.py | 187 +++ fairseq/models/composite_encoder.py | 57 + fairseq/models/distributed_fairseq_model.py | 103 ++ fairseq/models/fairseq_decoder.py | 90 ++ fairseq/models/fairseq_encoder.py | 92 ++ fairseq/models/fairseq_incremental_decoder.py | 118 ++ fairseq/models/fairseq_model.py | 573 ++++++++ fairseq/models/transformer.py | 1036 +++++++++++++++ fairseq/models/transformer_align.py | 93 ++ .../models/transformer_from_pretrained_xlm.py | 152 +++ fairseq/models/transformer_lm.py | 397 ++++++ fairseq/models/wav2vec/__init__.py | 9 + fairseq/models/wav2vec/unispeech.py | 187 +++ fairseq/models/wav2vec/wav2vec.py | 735 +++++++++++ fairseq/models/wav2vec/wav2vec2.py | 1047 +++++++++++++++ fairseq/models/wav2vec/wav2vec2_asr.py | 698 ++++++++++ fairseq/models/wav2vec/wav2vec2_s1.py | 1067 +++++++++++++++ fairseq/modules/__init__.py | 76 ++ fairseq/modules/adaptive_input.py | 80 ++ fairseq/modules/adaptive_softmax.py | 268 ++++ fairseq/modules/beamable_mm.py | 49 + fairseq/modules/character_token_embedder.py | 214 +++ fairseq/modules/checkpoint_activations.py | 205 +++ fairseq/modules/conv_tbc.py | 42 + fairseq/modules/cross_entropy.py | 59 + fairseq/modules/cuda_utils.cu | 203 +++ .../downsampled_multihead_attention.py | 316 +++++ fairseq/modules/dynamic_convolution.py | 304 +++++ fairseq/modules/dynamic_crf_layer.py | 189 +++ fairseq/modules/dynamicconv_layer/__init__.py | 6 + .../dynamicconv_layer/cuda_function_gen.py | 223 ++++ .../dynamicconv_layer/dynamicconv_cuda.cpp | 56 + .../dynamicconv_layer/dynamicconv_cuda.cuh | 51 + .../dynamicconv_cuda_kernel.cu | 168 +++ .../dynamicconv_layer/dynamicconv_layer.py | 227 ++++ .../dynamicconv_layer/dynamiconv_cpu.cpp | 35 + fairseq/modules/dynamicconv_layer/setup.py | 23 + fairseq/modules/fairseq_dropout.py | 51 + fairseq/modules/fp32_group_norm.py | 25 + fairseq/modules/gelu.py | 25 + fairseq/modules/grad_multiply.py | 18 + fairseq/modules/gumbel_vector_quantizer.py | 199 +++ fairseq/modules/kmeans_vector_quantizer.py | 127 ++ fairseq/modules/layer_drop.py | 44 + fairseq/modules/layer_norm.py | 50 + .../modules/learned_positional_embedding.py | 61 + fairseq/modules/lightconv_layer/__init__.py | 6 + .../lightconv_layer/cuda_function_gen.py | 289 +++++ .../lightconv_layer/lightconv_cuda.cpp | 54 + .../lightconv_layer/lightconv_cuda.cuh | 83 ++ .../lightconv_layer/lightconv_cuda_kernel.cu | 375 ++++++ .../lightconv_layer/lightconv_layer.py | 137 ++ fairseq/modules/lightconv_layer/setup.py | 23 + fairseq/modules/lightweight_convolution.py | 310 +++++ fairseq/modules/linearized_convolution.py | 104 ++ fairseq/modules/multihead_attention.py | 488 +++++++ fairseq/modules/positional_embedding.py | 35 + fairseq/modules/quant_noise.py | 107 ++ fairseq/modules/quantization/__init__.py | 0 fairseq/modules/quantization/pq/__init__.py | 6 + fairseq/modules/quantization/pq/em.py | 211 +++ .../quantization/pq/modules/__init__.py | 8 + .../modules/quantization/pq/modules/qconv.py | 115 ++ .../modules/quantization/pq/modules/qemb.py | 107 ++ .../quantization/pq/modules/qlinear.py | 71 + fairseq/modules/quantization/pq/pq.py | 128 ++ fairseq/modules/quantization/pq/utils.py | 337 +++++ .../quantization/quantization_options.py | 44 + .../modules/quantization/scalar/__init__.py | 6 + .../quantization/scalar/modules/__init__.py | 9 + .../quantization/scalar/modules/qact.py | 88 ++ .../quantization/scalar/modules/qconv.py | 149 +++ .../quantization/scalar/modules/qemb.py | 147 +++ .../quantization/scalar/modules/qlinear.py | 113 ++ fairseq/modules/quantization/scalar/ops.py | 49 + fairseq/modules/quantization/scalar/utils.py | 77 ++ fairseq/modules/same_pad.py | 18 + fairseq/modules/scalar_bias.py | 31 + .../sinusoidal_positional_embedding.py | 105 ++ fairseq/modules/sparse_multihead_attention.py | 140 ++ .../sparse_transformer_sentence_encoder.py | 96 ++ ...arse_transformer_sentence_encoder_layer.py | 51 + fairseq/modules/transformer_layer.py | 414 ++++++ .../modules/transformer_sentence_encoder.py | 286 ++++ .../transformer_sentence_encoder_layer.py | 134 ++ fairseq/modules/transpose_last.py | 20 + fairseq/modules/unfold.py | 19 + fairseq/modules/vggblock.py | 116 ++ fairseq/nan_detector.py | 108 ++ fairseq/optim/__init__.py | 48 + fairseq/optim/adadelta.py | 47 + fairseq/optim/adafactor.py | 268 ++++ fairseq/optim/adagrad.py | 40 + fairseq/optim/adam.py | 226 ++++ fairseq/optim/adamax.py | 172 +++ fairseq/optim/bmuf.py | 200 +++ fairseq/optim/dynamic_loss_scaler.py | 70 + fairseq/optim/fairseq_optimizer.py | 150 +++ fairseq/optim/fp16_optimizer.py | 496 +++++++ fairseq/optim/fused_adam.py | 348 +++++ fairseq/optim/fused_lamb.py | 51 + fairseq/optim/lr_scheduler/__init__.py | 36 + .../optim/lr_scheduler/cosine_lr_scheduler.py | 154 +++ .../lr_scheduler/fairseq_lr_scheduler.py | 56 + fairseq/optim/lr_scheduler/fixed_schedule.py | 70 + .../inverse_square_root_schedule.py | 94 ++ .../lr_scheduler/polynomial_decay_schedule.py | 82 ++ .../lr_scheduler/reduce_lr_on_plateau.py | 115 ++ .../lr_scheduler/tri_stage_lr_scheduler.py | 165 +++ .../lr_scheduler/triangular_lr_scheduler.py | 74 ++ fairseq/optim/nag.py | 111 ++ fairseq/optim/sgd.py | 43 + fairseq/optim/shard.py | 41 + fairseq/options.py | 363 ++++++ fairseq/pdb.py | 47 + fairseq/quantization_utils.py | 143 ++ fairseq/registry.py | 84 ++ fairseq/scoring/__init__.py | 56 + fairseq/scoring/bleu.py | 167 +++ fairseq/scoring/chrf.py | 27 + fairseq/scoring/tokenizer.py | 67 + fairseq/scoring/wer.py | 58 + fairseq/search.py | 814 ++++++++++++ fairseq/sequence_generator.py | 1005 ++++++++++++++ fairseq/sequence_scorer.py | 153 +++ fairseq/tasks/__init__.py | 105 ++ fairseq/tasks/audio_pretraining.py | 335 +++++ fairseq/tasks/fairseq_task.py | 567 ++++++++ fairseq/token_generation_constraints.py | 506 ++++++++ fairseq/tokenizer.py | 15 + fairseq/trainer.py | 1149 +++++++++++++++++ fairseq/utils.py | 720 +++++++++++ fairseq/version.txt | 1 + fairseq_cli/__init__.py | 0 fairseq_cli/eval_lm.py | 284 ++++ fairseq_cli/generate.py | 393 ++++++ fairseq_cli/interactive.py | 322 +++++ fairseq_cli/preprocess.py | 398 ++++++ fairseq_cli/score.py | 102 ++ fairseq_cli/train.py | 353 +++++ fairseq_cli/validate.py | 140 ++ hubconf.py | 69 + pyproject.toml | 3 + scripts/__init__.py | 0 scripts/average_checkpoints.py | 158 +++ scripts/build_sym_alignment.py | 97 ++ scripts/compare_namespaces.py | 46 + scripts/compound_split_bleu.sh | 20 + scripts/constraints/extract.py | 92 ++ scripts/constraints/validate.py | 34 + scripts/convert_dictionary.lua | 34 + scripts/convert_model.lua | 108 ++ scripts/count_docs.py | 58 + scripts/read_binarized.py | 48 + scripts/rm_pt.py | 141 ++ scripts/sacrebleu.sh | 27 + scripts/shard_docs.py | 54 + scripts/split_train_valid_docs.py | 86 ++ scripts/spm_decode.py | 53 + scripts/spm_encode.py | 119 ++ scripts/spm_train.py | 16 + setup.py | 236 ++++ train.py | 14 + 310 files changed, 48477 insertions(+) create mode 100644 LICENSE create mode 100644 README.md create mode 100644 config/config.yaml create mode 100644 config/criterion/adaptive_loss.yaml create mode 100644 config/criterion/cross_entropy.yaml create mode 100644 config/lr_scheduler/cosine.yaml create mode 100644 config/lr_scheduler/inverse_sqrt.yaml create mode 100644 config/model/transformer_lm.yaml create mode 100644 config/model/transformer_lm_baevski_gbw.yaml create mode 100644 config/model/transformer_lm_baevski_wiki103.yaml create mode 100644 config/model/transformer_lm_big.yaml create mode 100644 config/model/transformer_lm_gbw.yaml create mode 100644 config/model/transformer_lm_gpt.yaml create mode 100644 config/model/transformer_lm_gpt2_big.yaml create mode 100644 config/model/transformer_lm_gpt2_medium.yaml create mode 100644 config/model/transformer_lm_gpt2_small.yaml create mode 100644 config/model/transformer_lm_wiki103.yaml create mode 100644 config/optimizer/adam.yaml create mode 100644 config/optimizer/nag.yaml create mode 100644 config/task/language_modeling.yaml create mode 100644 examples/.gitignore create mode 100644 examples/__init__.py create mode 100644 examples/speech_recognition/README.md create mode 100644 examples/speech_recognition/__init__.py create mode 100644 examples/speech_recognition/criterions/ASG_loss.py create mode 100644 examples/speech_recognition/criterions/__init__.py create mode 100644 examples/speech_recognition/criterions/cross_entropy_acc.py create mode 100644 examples/speech_recognition/data/__init__.py create mode 100644 examples/speech_recognition/data/asr_dataset.py create mode 100644 examples/speech_recognition/data/collaters.py create mode 100644 examples/speech_recognition/data/data_utils.py create mode 100644 examples/speech_recognition/data/replabels.py create mode 100644 examples/speech_recognition/datasets/asr_prep_json.py create mode 100644 examples/speech_recognition/datasets/prepare-librispeech.sh create mode 100644 examples/speech_recognition/infer.py create mode 100644 examples/speech_recognition/models/__init__.py create mode 100644 examples/speech_recognition/models/vggtransformer.py create mode 100644 examples/speech_recognition/models/w2l_conv_glu_enc.py create mode 100644 examples/speech_recognition/tasks/__init__.py create mode 100644 examples/speech_recognition/tasks/speech_recognition.py create mode 100644 examples/speech_recognition/utils/wer_utils.py create mode 100644 examples/speech_recognition/w2l_decoder.py create mode 100644 examples/unispeech/adjust_sample_rate.py create mode 100644 examples/unispeech/libri_labels.py create mode 100644 examples/unispeech/scripts/continue_pretrain.sh create mode 100644 examples/unispeech/scripts/finetune.sh create mode 100644 examples/unispeech/scripts/multilingal_large_pretrain.sh create mode 100644 examples/unispeech/scripts/one2one_large_pretrain_en1350.sh create mode 100644 examples/unispeech/unispeech_manifest.py create mode 100644 examples/unispeech/wav2vec_manifest.py create mode 100644 fairseq/__init__.py create mode 100644 fairseq/benchmark/__init__.py create mode 100644 fairseq/benchmark/dummy_lm.py create mode 100644 fairseq/benchmark/dummy_masked_lm.py create mode 100644 fairseq/benchmark/dummy_model.py create mode 100644 fairseq/benchmark/dummy_mt.py create mode 100644 fairseq/binarizer.py create mode 100644 fairseq/checkpoint_utils.py create mode 100644 fairseq/clib/libbleu/libbleu.cpp create mode 100644 fairseq/clib/libbleu/module.cpp create mode 100644 fairseq/clib/libnat/edit_dist.cpp create mode 100644 fairseq/clib/libnat_cuda/binding.cpp create mode 100644 fairseq/clib/libnat_cuda/edit_dist.cu create mode 100644 fairseq/clib/libnat_cuda/edit_dist.h create mode 100644 fairseq/criterions/__init__.py create mode 100644 fairseq/criterions/ctc.py create mode 100644 fairseq/criterions/fairseq_criterion.py create mode 100644 fairseq/criterions/label_smoothed_cross_entropy.py create mode 100644 fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py create mode 100644 fairseq/criterions/wav2vec_criterion.py create mode 100644 fairseq/criterions/wav2vec_mtl.py create mode 100644 fairseq/data/__init__.py create mode 100644 fairseq/data/add_target_dataset.py create mode 100644 fairseq/data/audio/__init__.py create mode 100644 fairseq/data/audio/audio_utils.py create mode 100644 fairseq/data/audio/feature_transforms/__init__.py create mode 100644 fairseq/data/audio/feature_transforms/global_cmvn.py create mode 100644 fairseq/data/audio/feature_transforms/specaugment.py create mode 100644 fairseq/data/audio/feature_transforms/utterance_cmvn.py create mode 100644 fairseq/data/audio/raw_audio_dataset.py create mode 100644 fairseq/data/audio/speech_to_text_dataset.py create mode 100644 fairseq/data/base_wrapper_dataset.py create mode 100644 fairseq/data/concat_dataset.py create mode 100644 fairseq/data/concat_sentences_dataset.py create mode 100644 fairseq/data/data_utils.py create mode 100644 fairseq/data/data_utils_fast.pyx create mode 100644 fairseq/data/dictionary.py create mode 100644 fairseq/data/encoders/__init__.py create mode 100644 fairseq/data/encoders/byte_bpe.py create mode 100644 fairseq/data/encoders/byte_utils.py create mode 100644 fairseq/data/encoders/bytes.py create mode 100644 fairseq/data/encoders/characters.py create mode 100644 fairseq/data/encoders/fastbpe.py create mode 100644 fairseq/data/encoders/gpt2_bpe.py create mode 100644 fairseq/data/encoders/gpt2_bpe_utils.py create mode 100644 fairseq/data/encoders/hf_bert_bpe.py create mode 100644 fairseq/data/encoders/hf_byte_bpe.py create mode 100644 fairseq/data/encoders/moses_tokenizer.py create mode 100644 fairseq/data/encoders/nltk_tokenizer.py create mode 100644 fairseq/data/encoders/sentencepiece_bpe.py create mode 100644 fairseq/data/encoders/space_tokenizer.py create mode 100644 fairseq/data/encoders/subword_nmt_bpe.py create mode 100644 fairseq/data/encoders/utils.py create mode 100644 fairseq/data/fairseq_dataset.py create mode 100644 fairseq/data/fasta_dataset.py create mode 100644 fairseq/data/id_dataset.py create mode 100644 fairseq/data/indexed_dataset.py create mode 100644 fairseq/data/iterators.py create mode 100644 fairseq/data/legacy/__init__.py create mode 100644 fairseq/data/legacy/block_pair_dataset.py create mode 100644 fairseq/data/legacy/masked_lm_dataset.py create mode 100644 fairseq/data/legacy/masked_lm_dictionary.py create mode 100644 fairseq/data/plasma_utils.py create mode 100644 fairseq/data/resampling_dataset.py create mode 100644 fairseq/data/token_block_dataset.py create mode 100644 fairseq/data/token_block_utils_fast.pyx create mode 100644 fairseq/dataclass/__init__.py create mode 100644 fairseq/dataclass/configs.py create mode 100644 fairseq/dataclass/constants.py create mode 100644 fairseq/dataclass/initialize.py create mode 100644 fairseq/dataclass/utils.py create mode 100644 fairseq/distributed_utils.py create mode 100644 fairseq/file_io.py create mode 100644 fairseq/file_utils.py create mode 100644 fairseq/hub_utils.py create mode 100644 fairseq/incremental_decoding_utils.py create mode 100644 fairseq/iterative_refinement_generator.py create mode 100644 fairseq/legacy_distributed_data_parallel.py create mode 100644 fairseq/logging/__init__.py create mode 100644 fairseq/logging/meters.py create mode 100644 fairseq/logging/metrics.py create mode 100644 fairseq/logging/progress_bar.py create mode 100644 fairseq/model_parallel/__init__.py create mode 100644 fairseq/model_parallel/criterions/__init__.py create mode 100644 fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py create mode 100644 fairseq/model_parallel/megatron_trainer.py create mode 100644 fairseq/model_parallel/models/__init__.py create mode 100644 fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py create mode 100644 fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py create mode 100644 fairseq/model_parallel/models/pipeline_parallel_transformer/model.py create mode 100644 fairseq/model_parallel/models/transformer.py create mode 100644 fairseq/model_parallel/models/transformer_lm.py create mode 100644 fairseq/model_parallel/modules/__init__.py create mode 100644 fairseq/model_parallel/modules/multihead_attention.py create mode 100644 fairseq/model_parallel/modules/transformer_layer.py create mode 100644 fairseq/model_parallel/modules/transformer_sentence_encoder.py create mode 100644 fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py create mode 100644 fairseq/models/__init__.py create mode 100644 fairseq/models/composite_encoder.py create mode 100644 fairseq/models/distributed_fairseq_model.py create mode 100644 fairseq/models/fairseq_decoder.py create mode 100644 fairseq/models/fairseq_encoder.py create mode 100644 fairseq/models/fairseq_incremental_decoder.py create mode 100644 fairseq/models/fairseq_model.py create mode 100644 fairseq/models/transformer.py create mode 100644 fairseq/models/transformer_align.py create mode 100644 fairseq/models/transformer_from_pretrained_xlm.py create mode 100644 fairseq/models/transformer_lm.py create mode 100644 fairseq/models/wav2vec/__init__.py create mode 100644 fairseq/models/wav2vec/unispeech.py create mode 100644 fairseq/models/wav2vec/wav2vec.py create mode 100644 fairseq/models/wav2vec/wav2vec2.py create mode 100644 fairseq/models/wav2vec/wav2vec2_asr.py create mode 100644 fairseq/models/wav2vec/wav2vec2_s1.py create mode 100644 fairseq/modules/__init__.py create mode 100644 fairseq/modules/adaptive_input.py create mode 100644 fairseq/modules/adaptive_softmax.py create mode 100644 fairseq/modules/beamable_mm.py create mode 100644 fairseq/modules/character_token_embedder.py create mode 100644 fairseq/modules/checkpoint_activations.py create mode 100644 fairseq/modules/conv_tbc.py create mode 100644 fairseq/modules/cross_entropy.py create mode 100644 fairseq/modules/cuda_utils.cu create mode 100644 fairseq/modules/downsampled_multihead_attention.py create mode 100644 fairseq/modules/dynamic_convolution.py create mode 100644 fairseq/modules/dynamic_crf_layer.py create mode 100644 fairseq/modules/dynamicconv_layer/__init__.py create mode 100644 fairseq/modules/dynamicconv_layer/cuda_function_gen.py create mode 100644 fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp create mode 100644 fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh create mode 100644 fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu create mode 100644 fairseq/modules/dynamicconv_layer/dynamicconv_layer.py create mode 100644 fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp create mode 100644 fairseq/modules/dynamicconv_layer/setup.py create mode 100644 fairseq/modules/fairseq_dropout.py create mode 100644 fairseq/modules/fp32_group_norm.py create mode 100644 fairseq/modules/gelu.py create mode 100644 fairseq/modules/grad_multiply.py create mode 100644 fairseq/modules/gumbel_vector_quantizer.py create mode 100644 fairseq/modules/kmeans_vector_quantizer.py create mode 100644 fairseq/modules/layer_drop.py create mode 100644 fairseq/modules/layer_norm.py create mode 100644 fairseq/modules/learned_positional_embedding.py create mode 100644 fairseq/modules/lightconv_layer/__init__.py create mode 100644 fairseq/modules/lightconv_layer/cuda_function_gen.py create mode 100644 fairseq/modules/lightconv_layer/lightconv_cuda.cpp create mode 100644 fairseq/modules/lightconv_layer/lightconv_cuda.cuh create mode 100644 fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu create mode 100644 fairseq/modules/lightconv_layer/lightconv_layer.py create mode 100644 fairseq/modules/lightconv_layer/setup.py create mode 100644 fairseq/modules/lightweight_convolution.py create mode 100644 fairseq/modules/linearized_convolution.py create mode 100644 fairseq/modules/multihead_attention.py create mode 100644 fairseq/modules/positional_embedding.py create mode 100644 fairseq/modules/quant_noise.py create mode 100644 fairseq/modules/quantization/__init__.py create mode 100644 fairseq/modules/quantization/pq/__init__.py create mode 100644 fairseq/modules/quantization/pq/em.py create mode 100644 fairseq/modules/quantization/pq/modules/__init__.py create mode 100644 fairseq/modules/quantization/pq/modules/qconv.py create mode 100644 fairseq/modules/quantization/pq/modules/qemb.py create mode 100644 fairseq/modules/quantization/pq/modules/qlinear.py create mode 100644 fairseq/modules/quantization/pq/pq.py create mode 100644 fairseq/modules/quantization/pq/utils.py create mode 100644 fairseq/modules/quantization/quantization_options.py create mode 100644 fairseq/modules/quantization/scalar/__init__.py create mode 100644 fairseq/modules/quantization/scalar/modules/__init__.py create mode 100644 fairseq/modules/quantization/scalar/modules/qact.py create mode 100644 fairseq/modules/quantization/scalar/modules/qconv.py create mode 100644 fairseq/modules/quantization/scalar/modules/qemb.py create mode 100644 fairseq/modules/quantization/scalar/modules/qlinear.py create mode 100644 fairseq/modules/quantization/scalar/ops.py create mode 100644 fairseq/modules/quantization/scalar/utils.py create mode 100644 fairseq/modules/same_pad.py create mode 100644 fairseq/modules/scalar_bias.py create mode 100644 fairseq/modules/sinusoidal_positional_embedding.py create mode 100644 fairseq/modules/sparse_multihead_attention.py create mode 100644 fairseq/modules/sparse_transformer_sentence_encoder.py create mode 100644 fairseq/modules/sparse_transformer_sentence_encoder_layer.py create mode 100644 fairseq/modules/transformer_layer.py create mode 100644 fairseq/modules/transformer_sentence_encoder.py create mode 100644 fairseq/modules/transformer_sentence_encoder_layer.py create mode 100644 fairseq/modules/transpose_last.py create mode 100644 fairseq/modules/unfold.py create mode 100644 fairseq/modules/vggblock.py create mode 100644 fairseq/nan_detector.py create mode 100644 fairseq/optim/__init__.py create mode 100644 fairseq/optim/adadelta.py create mode 100644 fairseq/optim/adafactor.py create mode 100644 fairseq/optim/adagrad.py create mode 100644 fairseq/optim/adam.py create mode 100644 fairseq/optim/adamax.py create mode 100644 fairseq/optim/bmuf.py create mode 100644 fairseq/optim/dynamic_loss_scaler.py create mode 100644 fairseq/optim/fairseq_optimizer.py create mode 100644 fairseq/optim/fp16_optimizer.py create mode 100644 fairseq/optim/fused_adam.py create mode 100644 fairseq/optim/fused_lamb.py create mode 100644 fairseq/optim/lr_scheduler/__init__.py create mode 100644 fairseq/optim/lr_scheduler/cosine_lr_scheduler.py create mode 100644 fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py create mode 100644 fairseq/optim/lr_scheduler/fixed_schedule.py create mode 100644 fairseq/optim/lr_scheduler/inverse_square_root_schedule.py create mode 100644 fairseq/optim/lr_scheduler/polynomial_decay_schedule.py create mode 100644 fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py create mode 100644 fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py create mode 100644 fairseq/optim/lr_scheduler/triangular_lr_scheduler.py create mode 100644 fairseq/optim/nag.py create mode 100644 fairseq/optim/sgd.py create mode 100644 fairseq/optim/shard.py create mode 100644 fairseq/options.py create mode 100644 fairseq/pdb.py create mode 100644 fairseq/quantization_utils.py create mode 100644 fairseq/registry.py create mode 100644 fairseq/scoring/__init__.py create mode 100644 fairseq/scoring/bleu.py create mode 100644 fairseq/scoring/chrf.py create mode 100644 fairseq/scoring/tokenizer.py create mode 100644 fairseq/scoring/wer.py create mode 100644 fairseq/search.py create mode 100644 fairseq/sequence_generator.py create mode 100644 fairseq/sequence_scorer.py create mode 100644 fairseq/tasks/__init__.py create mode 100644 fairseq/tasks/audio_pretraining.py create mode 100644 fairseq/tasks/fairseq_task.py create mode 100644 fairseq/token_generation_constraints.py create mode 100644 fairseq/tokenizer.py create mode 100644 fairseq/trainer.py create mode 100644 fairseq/utils.py create mode 100644 fairseq/version.txt create mode 100644 fairseq_cli/__init__.py create mode 100644 fairseq_cli/eval_lm.py create mode 100644 fairseq_cli/generate.py create mode 100644 fairseq_cli/interactive.py create mode 100644 fairseq_cli/preprocess.py create mode 100644 fairseq_cli/score.py create mode 100644 fairseq_cli/train.py create mode 100644 fairseq_cli/validate.py create mode 100644 hubconf.py create mode 100644 pyproject.toml create mode 100644 scripts/__init__.py create mode 100644 scripts/average_checkpoints.py create mode 100644 scripts/build_sym_alignment.py create mode 100644 scripts/compare_namespaces.py create mode 100644 scripts/compound_split_bleu.sh create mode 100644 scripts/constraints/extract.py create mode 100644 scripts/constraints/validate.py create mode 100644 scripts/convert_dictionary.lua create mode 100644 scripts/convert_model.lua create mode 100644 scripts/count_docs.py create mode 100644 scripts/read_binarized.py create mode 100644 scripts/rm_pt.py create mode 100644 scripts/sacrebleu.sh create mode 100644 scripts/shard_docs.py create mode 100644 scripts/split_train_valid_docs.py create mode 100644 scripts/spm_decode.py create mode 100644 scripts/spm_encode.py create mode 100644 scripts/spm_train.py create mode 100644 setup.py create mode 100644 train.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..72fba99 --- /dev/null +++ b/LICENSE @@ -0,0 +1,74 @@ +Attribution-ShareAlike 3.0 Unported + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS LICENSE DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE INFORMATION PROVIDED, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM ITS USE. + +License + +THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED. + +BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS. + +1. Definitions + + "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License. + "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(f) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined below) for the purposes of this License. + "Creative Commons Compatible License" means a license that is listed at https://creativecommons.org/compatiblelicenses that has been approved by Creative Commons as being essentially equivalent to this License, including, at a minimum, because that license: (i) contains terms that have the same purpose, meaning and effect as the License Elements of this License; and, (ii) explicitly permits the relicensing of adaptations of works made available under that license under this License or a Creative Commons jurisdiction license with the same License Elements as this License. + "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership. + "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, ShareAlike. + "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License. + "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast. + "Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work. + "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. + "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images. + "Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium. + +2. Fair Dealing Rights. Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws. + +3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below: + + to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections; + to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified."; + to Distribute and Publicly Perform the Work including as incorporated in Collections; and, + to Distribute and Publicly Perform Adaptations. + + For the avoidance of doubt: + Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; + Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and, + Voluntary License Schemes. The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License. + +The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved. + +4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions: + + You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(c), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(c), as requested. + You may Distribute or Publicly Perform an Adaptation only under the terms of: (i) this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-ShareAlike 3.0 US)); (iv) a Creative Commons Compatible License. If you license the Adaptation under one of the licenses mentioned in (iv), you must comply with the terms of that license. If you license the Adaptation under the terms of any of the licenses mentioned in (i), (ii) or (iii) (the "Applicable License"), you must comply with the terms of the Applicable License generally and the following provisions: (I) You must include a copy of, or the URI for, the Applicable License with every copy of each Adaptation You Distribute or Publicly Perform; (II) You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License; (III) You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform; (IV) when You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License. + If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and (iv) , consistent with Ssection 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(c) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties. + Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise. + +5. Representations, Warranties and Disclaimer + +UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU. + +6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +7. Termination + + This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License. + Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above. + +8. Miscellaneous + + Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License. + Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License. + If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. + No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent. + This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You. + The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law. + + Creative Commons Notice + + Creative Commons is not a party to this License, and makes no warranty whatsoever in connection with the Work. Creative Commons will not be liable to You or any party on any legal theory for any damages whatsoever, including without limitation any general, special, incidental or consequential damages arising in connection to this license. Notwithstanding the foregoing two (2) sentences, if Creative Commons has expressly identified itself as the Licensor hereunder, it shall have all rights and obligations of Licensor. + + Except for the limited purpose of indicating to the public that the Work is licensed under the CCPL, Creative Commons does not authorize the use by either party of the trademark "Creative Commons" or any related trademark or logo of Creative Commons without the prior written consent of Creative Commons. Any permitted use will be in compliance with Creative Commons' then-current trademark usage guidelines, as may be published on its website or otherwise made available upon request from time to time. For the avoidance of doubt, this trademark restriction does not form part of the License. + + Creative Commons may be contacted at https://creativecommons.org/. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..86f7f22 --- /dev/null +++ b/README.md @@ -0,0 +1,58 @@ +# UniSpeech + +This is the official implementation of paper "[UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597)". The implementation mainly based on [fairseq](https://github.com/pytorch/fairseq) codebase. We release the training recipes on CommonVoice dataset. + +## Requirements and Installation + + - Pytorch >= 1.4.0 + - python version >= 3.6 + ``` bash + cd unispeech + pip install soundfile + pip install librosa + pip install pydub + pip install --editable ./ + ``` +## Data Preparation +Download pretraining audio data from [here](https://commonvoice.mozilla.org/datasets). (We use the June 2020 release version in our paper). +Get the wav list and the transcription for each dataset by run: +``` +python examples/unispeech/unispeech_manifest.py input_meta_file --dest examples/unispeech/data/LANG +``` + +Then convert the audio files in common voices to 16k HZ using the commond: +``` +python examples/unispeech/adjust_sample_rate.py --wav-path /path/to/wav/ --dest-path /path/to/16kwav/ --input examples/unispeech/data/LANG/*.tsv --output examples/unispeech/data/LANG/*_16k.tsv +``` +For the finetuning data, our train/val/test splits are following [this](https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz). +The phoneme transcriptions are generated by [phonemizer](https://github.com/bootphon/phonemizer) to convert texts to phonemes. Then we create .id files using different vocabularies. All our pre-processed data as well as the dictionaries can be downloaded from [here]. + +## Pretraining + +We give the training examples for large model here. +### Stage 1. Pretraining UniSpeech with labeled data. +The following script can be used to pre-train an English model: +``` +bash examples/unispeech/scripts/one2one_large_pretrain_en1350.sh +``` +To train a multilingual model: +``` +bash examples/unispeech/scripts/multilingual_large_pretrain.sh +``` + +### Stage 2. Continue pre-training with low-resource unlabeled data. (Optional) +After stage 1, you can continue pre-training the UniSpeech model with only contrastive loss: +``` +bash examples/unispeech/scripts/continue_pretran.sh +``` + +### Stage 3. Finetuning with low-resource labeled data. +Finally, fint-tune the model with 1 hour labeled data. +For multilingual models, you can choose to use separate vocabulary (examples/unispeech/data/en/vocab_sep.json) or shared vocabulary (examples/unispeech/data/en/vocab_share.json) +``` +bash examples/unispeech/scripts/finetune.sh +``` + +### Models +We release our pre-trained English large model and multilingual large model in stage 1 at [here]. You can finetune the model to get the reported results in our paper. + diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..b9ee6c7 --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,111 @@ +# @package _group_ +common: + no_progress_bar: false + log_interval: 100 + log_format: null + tensorboard_logdir: null + seed: 1 + cpu: false + tpu: false + bf16: false + fp16: false + memory_efficient_fp16: false + memory_efficient_bf16: false + fp16_no_flatten_grads: false + fp16_init_scale: 128 + fp16_scale_window: null + fp16_scale_tolerance: 0.0 + min_loss_scale: 1.0e-4 + threshold_loss_scale: null + user_dir: null + empty_cache_freq: 0 + all_gather_list_size: 16384 + model_parallel_size: 1 + quantization_config_path: null + profile: false +distributed_training: + distributed_rank: 0 + distributed_backend: "nccl" + distributed_init_method: null + distributed_port: -1 + device_id: 0 + local_rank: 0 + distributed_no_spawn: false + ddp_backend: "c10d" + bucket_cap_mb: 25 + fix_batches_to_gpus: false + find_unused_parameters: false + fast_stat_sync: false + broadcast_buffers: false + distributed_wrapper: "DDP" + slowmo_momentum: null + slowmo_algorithm: "LocalSGD" + localsgd_frequency: 3 +dataset: + num_workers: 1 + skip_invalid_size_inputs_valid_test: false + max_tokens: null + batch_size: null + required_batch_size_multiple: 8 + dataset_impl: null + data_buffer_size: 10 + train_subset: "train" + valid_subset: "valid" + validate_interval: 1 + fixed_validation_seed: null + disable_validation: false + curriculum: 0 + gen_subset: "test" + num_shards: 1 + shard_id: 0 + max_tokens_valid: ${dataset.max_tokens} + batch_size_valid: ${dataset.batch_size} +optimization: + max_epoch: 0 + max_update: 0 + clip_norm: 25.0 + sentence_avg: false + update_freq: [ 1 ] + lr: [ 0.25 ] + min_lr: -1.0 + use_bmuf: false +checkpoint: + save_dir: "checkpoints" + restore_file: "checkpoint_last.pt" + reset_dataloader: false + reset_lr_scheduler: false + reset_meters: false + reset_optimizer: false + optimizer_overrides: "{}" + save_interval: 1 + save_interval_updates: 0 + keep_interval_updates: -1 + keep_last_epochs: -1 + keep_best_checkpoints: -1 + no_save: false + no_epoch_checkpoints: false + no_last_checkpoints: false + no_save_optimizer_state: false + best_checkpoint_metric: "loss" + maximize_best_checkpoint_metric: false + patience: -1 + checkpoint_suffix: "" +bmuf: + block_lr: 1 + block_momentum: 0.875 + global_sync_iter: 50 + warmup_iterations: 500 + use_nbm: false + average_sync: false +defaults: + - task: language_modeling + - model: null + - criterion: null + - optimizer: null + - lr_scheduler: null + - bpe: null + - tokenizer: null + - scoring: null + - generation: null + - common_eval: null + - eval_lm: null diff --git a/config/criterion/adaptive_loss.yaml b/config/criterion/adaptive_loss.yaml new file mode 100644 index 0000000..7997b07 --- /dev/null +++ b/config/criterion/adaptive_loss.yaml @@ -0,0 +1,3 @@ +# @package _group_ +sentence_avg: ${optimization.sentence_avg} +ddp_backend: ${distributed_training.ddp_backend} diff --git a/config/criterion/cross_entropy.yaml b/config/criterion/cross_entropy.yaml new file mode 100644 index 0000000..ad3d414 --- /dev/null +++ b/config/criterion/cross_entropy.yaml @@ -0,0 +1,2 @@ +# @package _group_ +sentence_avg: ${optimization.sentence_avg} diff --git a/config/lr_scheduler/cosine.yaml b/config/lr_scheduler/cosine.yaml new file mode 100644 index 0000000..0f91e0d --- /dev/null +++ b/config/lr_scheduler/cosine.yaml @@ -0,0 +1,7 @@ +# @package _group_ +warmup_updates: 0 +warmup_init_lr: -1 +max_lr: 1.0 +t_mult: 1.0 +lr_period_updates: -1 +lr_shrink: 0.1 diff --git a/config/lr_scheduler/inverse_sqrt.yaml b/config/lr_scheduler/inverse_sqrt.yaml new file mode 100644 index 0000000..0eac7d8 --- /dev/null +++ b/config/lr_scheduler/inverse_sqrt.yaml @@ -0,0 +1,3 @@ +# @package _group_ +warmup_updates: 4000 +warmup_init_lr: -1 diff --git a/config/model/transformer_lm.yaml b/config/model/transformer_lm.yaml new file mode 100644 index 0000000..3837ea5 --- /dev/null +++ b/config/model/transformer_lm.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.0 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 2048 +decoder_layers: 6 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_baevski_gbw.yaml b/config/model/transformer_lm_baevski_gbw.yaml new file mode 100644 index 0000000..30b1a4f --- /dev/null +++ b/config/model/transformer_lm_baevski_gbw.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_baevski_wiki103.yaml b/config/model/transformer_lm_baevski_wiki103.yaml new file mode 100644 index 0000000..1154cfa --- /dev/null +++ b/config/model/transformer_lm_baevski_wiki103.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.3 +attention_dropout: 0.1 +activation_dropout: 0.1 +relu_dropout: 0.1 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 16 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: "20000,60000" +adaptive_softmax_dropout: 0.2 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: true +adaptive_input_factor: 4 +adaptive_input_cutoff: "20000,60000" +tie_adaptive_weights: true +tie_adaptive_proj: true +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_big.yaml b/config/model/transformer_lm_big.yaml new file mode 100644 index 0000000..3095753 --- /dev/null +++ b/config/model/transformer_lm_big.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.0 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gbw.yaml b/config/model/transformer_lm_gbw.yaml new file mode 100644 index 0000000..30b1a4f --- /dev/null +++ b/config/model/transformer_lm_gbw.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gpt.yaml b/config/model/transformer_lm_gpt.yaml new file mode 100644 index 0000000..2c6cb7b --- /dev/null +++ b/config/model/transformer_lm_gpt.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 768 +decoder_output_dim: 768 +decoder_input_dim: 768 +decoder_ffn_embed_dim: 3072 +decoder_layers: 12 +decoder_attention_heads: 12 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gpt2_big.yaml b/config/model/transformer_lm_gpt2_big.yaml new file mode 100644 index 0000000..a08769a --- /dev/null +++ b/config/model/transformer_lm_gpt2_big.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1600 +decoder_output_dim: 1600 +decoder_input_dim: 1600 +decoder_ffn_embed_dim: 6400 +decoder_layers: 48 +decoder_attention_heads: 25 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gpt2_medium.yaml b/config/model/transformer_lm_gpt2_medium.yaml new file mode 100644 index 0000000..64261d7 --- /dev/null +++ b/config/model/transformer_lm_gpt2_medium.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1280 +decoder_output_dim: 1280 +decoder_input_dim: 1280 +decoder_ffn_embed_dim: 5120 +decoder_layers: 36 +decoder_attention_heads: 20 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_gpt2_small.yaml b/config/model/transformer_lm_gpt2_small.yaml new file mode 100644 index 0000000..702e81f --- /dev/null +++ b/config/model/transformer_lm_gpt2_small.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 24 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/model/transformer_lm_wiki103.yaml b/config/model/transformer_lm_wiki103.yaml new file mode 100644 index 0000000..1154cfa --- /dev/null +++ b/config/model/transformer_lm_wiki103.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.3 +attention_dropout: 0.1 +activation_dropout: 0.1 +relu_dropout: 0.1 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 16 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: "20000,60000" +adaptive_softmax_dropout: 0.2 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: true +adaptive_input_factor: 4 +adaptive_input_cutoff: "20000,60000" +tie_adaptive_weights: true +tie_adaptive_proj: true +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/config/optimizer/adam.yaml b/config/optimizer/adam.yaml new file mode 100644 index 0000000..e5264f8 --- /dev/null +++ b/config/optimizer/adam.yaml @@ -0,0 +1,5 @@ +# @package _group_ +adam_betas: "(0.9, 0.999)" +adam_eps: 1.0e-8 +weight_decay: 0 +use_old_adam: false diff --git a/config/optimizer/nag.yaml b/config/optimizer/nag.yaml new file mode 100644 index 0000000..4ab2745 --- /dev/null +++ b/config/optimizer/nag.yaml @@ -0,0 +1,3 @@ +# @package _group_ +momentum: 0.99 +weight_decay: 0.0 diff --git a/config/task/language_modeling.yaml b/config/task/language_modeling.yaml new file mode 100644 index 0000000..58a2ad1 --- /dev/null +++ b/config/task/language_modeling.yaml @@ -0,0 +1,10 @@ +# @package _group_ +data: ??? +sample_break_mode: "none" +tokens_per_sample: 1024 +output_dictionary_size: -1 +self_target: false +future_target: false +past_target: false +add_bos_token: false +max_target_positions: null diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 0000000..1ef816f --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1,2 @@ +!*/*.sh +!*/*.md diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..9e2c798 --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import __version__ # noqa diff --git a/examples/speech_recognition/README.md b/examples/speech_recognition/README.md new file mode 100644 index 0000000..19f7cc5 --- /dev/null +++ b/examples/speech_recognition/README.md @@ -0,0 +1,106 @@ +# Speech Recognition +`examples/speech_recognition` is implementing ASR task in Fairseq, along with needed features, datasets, models and loss functions to train and infer model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660). + + +## Additional dependencies +On top of main fairseq dependencies there are couple more additional requirements. + +1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features. +2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from sctk package [here](http://www.openslr.org/4/). Training and inference doesn't require Sclite dependency. +3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create dataset with word-piece targets. + +## Preparing librispeech data +``` +./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA +``` + +## Training librispeech data +``` +python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/ +``` + +## Inference for librispeech +`$SET` can be `test_clean` or `test_other` +Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt` +``` +python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/ +``` + +## Inference for librispeech +``` +sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT +``` +`Sum/Avg` row from first table of the report has WER + +## Using wav2letter components +[wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes: + +* AutoSegmentationCriterion (ASG) +* wav2letter-style Conv/GLU model +* wav2letter's beam search decoder + +To use these, follow the instructions on [this page](https://github.com/facebookresearch/wav2letter/tree/master/bindings/python) to install python bindings. Please note that python bindings are for a *subset* of wav2letter and don't require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required). + +To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps: +``` +# additional prerequisites - use equivalents for your distro +sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev +# install KenLM from source +git clone https://github.com/kpu/kenlm.git +cd kenlm +mkdir -p build && cd build +cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON +make -j16 +cd .. +export KENLM_ROOT_DIR=$(pwd) +cd .. +# install wav2letter python bindings +git clone https://github.com/facebookresearch/wav2letter.git +cd wav2letter/bindings/python +# make sure your python environment is active at this point +pip install torch packaging +pip install -e . +# try some examples to verify installation succeeded +python ./examples/criterion_example.py +python ./examples/decoder_example.py ../../src/decoder/test +python ./examples/feature_example.py ../../src/feature/test/data +``` + +## Training librispeech data (wav2letter style, Conv/GLU + ASG loss) +Training command: +``` +python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition +``` + +Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`. + +## Inference for librispeech (wav2letter decoder, n-gram LM) +Inference command: +``` +python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition +``` + +`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels): +``` +doorbell D O 1 R B E L 1 ▁ +``` +For CTC inference with word-pieces, repetition labels are not used and the lexicon should have most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this): +``` +doorbell ▁DOOR BE LL +doorbell ▁DOOR B E LL +doorbell ▁DO OR BE LL +doorbell ▁DOOR B EL L +doorbell ▁DOOR BE L L +doorbell ▁DO OR B E LL +doorbell ▁DOOR B E L L +doorbell ▁DO OR B EL L +doorbell ▁DO O R BE LL +doorbell ▁DO OR BE L L +``` +Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`). + +## Inference for librispeech (wav2letter decoder, viterbi only) +Inference command: +``` +python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition +``` diff --git a/examples/speech_recognition/__init__.py b/examples/speech_recognition/__init__.py new file mode 100644 index 0000000..0278f6a --- /dev/null +++ b/examples/speech_recognition/__init__.py @@ -0,0 +1 @@ +from . import criterions, models, tasks # noqa diff --git a/examples/speech_recognition/criterions/ASG_loss.py b/examples/speech_recognition/criterions/ASG_loss.py new file mode 100644 index 0000000..7493654 --- /dev/null +++ b/examples/speech_recognition/criterions/ASG_loss.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from examples.speech_recognition.data.replabels import pack_replabels +from fairseq import utils +from fairseq.criterions import FairseqCriterion, register_criterion + + +@register_criterion("asg_loss") +class ASGCriterion(FairseqCriterion): + @staticmethod + def add_args(parser): + group = parser.add_argument_group("ASG Loss") + group.add_argument( + "--asg-transitions-init", + help="initial diagonal value of transition matrix", + type=float, + default=0.0, + ) + group.add_argument( + "--max-replabel", help="maximum # of replabels", type=int, default=2 + ) + group.add_argument( + "--linseg-updates", + help="# of training updates to use LinSeg initialization", + type=int, + default=0, + ) + group.add_argument( + "--hide-linseg-messages", + help="hide messages about LinSeg initialization", + action="store_true", + ) + + def __init__( + self, + task, + silence_token, + asg_transitions_init, + max_replabel, + linseg_updates, + hide_linseg_messages, + ): + from wav2letter.criterion import ASGLoss, CriterionScaleMode + + super().__init__(task) + self.tgt_dict = task.target_dictionary + self.eos = self.tgt_dict.eos() + self.silence = ( + self.tgt_dict.index(silence_token) + if silence_token in self.tgt_dict + else None + ) + self.max_replabel = max_replabel + + num_labels = len(self.tgt_dict) + self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT) + self.asg.trans = torch.nn.Parameter( + asg_transitions_init * torch.eye(num_labels), requires_grad=True + ) + + self.linseg_progress = torch.nn.Parameter( + torch.tensor([0], dtype=torch.int), requires_grad=False + ) + self.linseg_maximum = linseg_updates + self.linseg_message_state = "none" if hide_linseg_messages else "start" + + @classmethod + def build_criterion(cls, args, task): + return cls( + task, + args.silence_token, + args.asg_transitions_init, + args.max_replabel, + args.linseg_updates, + args.hide_linseg_messages, + ) + + def linseg_step(self): + if not self.training: + return False + if self.linseg_progress.item() < self.linseg_maximum: + if self.linseg_message_state == "start": + print("| using LinSeg to initialize ASG") + self.linseg_message_state = "finish" + self.linseg_progress.add_(1) + return True + elif self.linseg_message_state == "finish": + print("| finished LinSeg initialization") + self.linseg_message_state = "none" + return False + + def replace_eos_with_silence(self, tgt): + if tgt[-1] != self.eos: + return tgt + elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence): + return tgt[:-1] + else: + return tgt[:-1] + [self.silence] + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + + net_output = model(**sample["net_input"]) + emissions = net_output["encoder_out"].transpose(0, 1).contiguous() + B = emissions.size(0) + T = emissions.size(1) + device = emissions.device + + target = torch.IntTensor(B, T) + target_size = torch.IntTensor(B) + using_linseg = self.linseg_step() + + for b in range(B): + initial_target_size = sample["target_lengths"][b].item() + if initial_target_size == 0: + raise ValueError("target size cannot be zero") + + tgt = sample["target"][b, :initial_target_size].tolist() + tgt = self.replace_eos_with_silence(tgt) + tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel) + tgt = tgt[:T] + + if using_linseg: + tgt = [tgt[t * len(tgt) // T] for t in range(T)] + + target[b][: len(tgt)] = torch.IntTensor(tgt) + target_size[b] = len(tgt) + + loss = self.asg.forward(emissions, target.to(device), target_size.to(device)) + + if reduce: + loss = torch.sum(loss) + + sample_size = ( + sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + agg_output = { + "loss": loss_sum / nsentences, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + return agg_output diff --git a/examples/speech_recognition/criterions/__init__.py b/examples/speech_recognition/criterions/__init__.py new file mode 100644 index 0000000..88af9f3 --- /dev/null +++ b/examples/speech_recognition/criterions/__init__.py @@ -0,0 +1,17 @@ +import importlib +import os + + +# ASG loss requires wav2letter +files_to_skip = set() +try: + import wav2letter +except ImportError: + files_to_skip.add("ASG_loss.py") + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip: + criterion_name = file[: file.find(".py")] + importlib.import_module( + "examples.speech_recognition.criterions." + criterion_name + ) diff --git a/examples/speech_recognition/criterions/cross_entropy_acc.py b/examples/speech_recognition/criterions/cross_entropy_acc.py new file mode 100644 index 0000000..7c4d8ba --- /dev/null +++ b/examples/speech_recognition/criterions/cross_entropy_acc.py @@ -0,0 +1,130 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.criterions import FairseqCriterion, register_criterion + + +@register_criterion("cross_entropy_acc") +class CrossEntropyWithAccCriterion(FairseqCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + + def compute_loss(self, model, net_output, target, reduction, log_probs): + # N, T -> N * T + target = target.view(-1) + lprobs = model.get_normalized_probs(net_output, log_probs=log_probs) + if not hasattr(lprobs, "batch_first"): + logging.warning( + "ERROR: we need to know whether " + "batch first for the net output; " + "you need to set batch_first attribute for the return value of " + "model.get_normalized_probs. Now, we assume this is true, but " + "in the future, we will raise exception instead. " + ) + batch_first = getattr(lprobs, "batch_first", True) + if not batch_first: + lprobs = lprobs.transpose(0, 1) + + # N, T, D -> N * T, D + lprobs = lprobs.view(-1, lprobs.size(-1)) + loss = F.nll_loss( + lprobs, target, ignore_index=self.padding_idx, reduction=reduction + ) + return lprobs, loss + + def get_logging_output(self, sample, target, lprobs, loss): + target = target.view(-1) + mask = target != self.padding_idx + correct = torch.sum( + lprobs.argmax(1).masked_select(mask) == target.masked_select(mask) + ) + total = torch.sum(mask) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + + logging_output = { + "loss": utils.item(loss.data), # * sample['ntokens'], + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "correct": utils.item(correct.data), + "total": utils.item(total.data), + "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(), + } + + return sample_size, logging_output + + def forward(self, model, sample, reduction="sum", log_probs=True): + """Computes the cross entropy with accuracy metric for the given sample. + + This is similar to CrossEntropyCriterion in fairseq, but also + computes accuracy metrics as part of logging + + Args: + logprobs (Torch.tensor) of shape N, T, D i.e. + batchsize, timesteps, dimensions + targets (Torch.tensor) of shape N, T i.e batchsize, timesteps + + Returns: + tuple: With three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + + TODO: + * Currently this Criterion will only work with LSTMEncoderModels or + FairseqModels which have decoder, or Models which return TorchTensor + as net_output. + We need to make a change to support all FairseqEncoder models. + """ + net_output = model(**sample["net_input"]) + target = model.get_targets(sample, net_output) + lprobs, loss = self.compute_loss( + model, net_output, target, reduction, log_probs + ) + sample_size, logging_output = self.get_logging_output( + sample, target, lprobs, loss + ) + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + correct_sum = sum(log.get("correct", 0) for log in logging_outputs) + total_sum = sum(log.get("total", 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + nframes = sum(log.get("nframes", 0) for log in logging_outputs) + agg_output = { + "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0, + # if args.sentence_avg, then sample_size is nsentences, then loss + # is per-sentence loss; else sample_size is ntokens, the loss + # becomes per-output token loss + "ntokens": ntokens, + "nsentences": nsentences, + "nframes": nframes, + "sample_size": sample_size, + "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0, + "correct": correct_sum, + "total": total_sum, + # total is the number of validate tokens + } + if sample_size != ntokens: + agg_output["nll_loss"] = loss_sum / ntokens / math.log(2) + # loss: per output token loss + # nll_loss: per sentence loss + return agg_output diff --git a/examples/speech_recognition/data/__init__.py b/examples/speech_recognition/data/__init__.py new file mode 100644 index 0000000..47bb6e2 --- /dev/null +++ b/examples/speech_recognition/data/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .asr_dataset import AsrDataset + + +__all__ = [ + "AsrDataset", +] diff --git a/examples/speech_recognition/data/asr_dataset.py b/examples/speech_recognition/data/asr_dataset.py new file mode 100644 index 0000000..63a6fca --- /dev/null +++ b/examples/speech_recognition/data/asr_dataset.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import numpy as np +from fairseq.data import FairseqDataset + +from . import data_utils +from .collaters import Seq2SeqCollater + + +class AsrDataset(FairseqDataset): + """ + A dataset representing speech and corresponding transcription. + + Args: + aud_paths: (List[str]): A list of str with paths to audio files. + aud_durations_ms (List[int]): A list of int containing the durations of + audio files. + tgt (List[torch.LongTensor]): A list of LongTensors containing the indices + of target transcriptions. + tgt_dict (~fairseq.data.Dictionary): target vocabulary. + ids (List[str]): A list of utterance IDs. + speakers (List[str]): A list of speakers corresponding to utterances. + num_mel_bins (int): Number of triangular mel-frequency bins (default: 80) + frame_length (float): Frame length in milliseconds (default: 25.0) + frame_shift (float): Frame shift in milliseconds (default: 10.0) + """ + + def __init__( + self, + aud_paths, + aud_durations_ms, + tgt, + tgt_dict, + ids, + speakers, + num_mel_bins=80, + frame_length=25.0, + frame_shift=10.0, + ): + assert frame_length > 0 + assert frame_shift > 0 + assert all(x > frame_length for x in aud_durations_ms) + self.frame_sizes = [ + int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms + ] + + assert len(aud_paths) > 0 + assert len(aud_paths) == len(aud_durations_ms) + assert len(aud_paths) == len(tgt) + assert len(aud_paths) == len(ids) + assert len(aud_paths) == len(speakers) + self.aud_paths = aud_paths + self.tgt_dict = tgt_dict + self.tgt = tgt + self.ids = ids + self.speakers = speakers + self.num_mel_bins = num_mel_bins + self.frame_length = frame_length + self.frame_shift = frame_shift + + self.s2s_collater = Seq2SeqCollater( + 0, + 1, + pad_index=self.tgt_dict.pad(), + eos_index=self.tgt_dict.eos(), + move_eos_to_beginning=True, + ) + + def __getitem__(self, index): + import torchaudio + import torchaudio.compliance.kaldi as kaldi + + tgt_item = self.tgt[index] if self.tgt is not None else None + + path = self.aud_paths[index] + if not os.path.exists(path): + raise FileNotFoundError("Audio file not found: {}".format(path)) + sound, sample_rate = torchaudio.load_wav(path) + output = kaldi.fbank( + sound, + num_mel_bins=self.num_mel_bins, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + ) + output_cmvn = data_utils.apply_mv_norm(output) + + return {"id": index, "data": [output_cmvn.detach(), tgt_item]} + + def __len__(self): + return len(self.aud_paths) + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[int]): sample indices to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + return self.s2s_collater.collate(samples) + + def num_tokens(self, index): + return self.frame_sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return ( + self.frame_sizes[index], + len(self.tgt[index]) if self.tgt is not None else 0, + ) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + return np.arange(len(self)) diff --git a/examples/speech_recognition/data/collaters.py b/examples/speech_recognition/data/collaters.py new file mode 100644 index 0000000..6acfec8 --- /dev/null +++ b/examples/speech_recognition/data/collaters.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" + This module contains collection of classes which implement + collate functionalities for various tasks. + + Collaters should know what data to expect for each sample + and they should pack / collate them into batches +""" + + +from __future__ import absolute_import, division, print_function, unicode_literals + +import numpy as np +import torch +from fairseq.data import data_utils as fairseq_data_utils + + +class Seq2SeqCollater(object): + """ + Implements collate function mainly for seq2seq tasks + This expects each sample to contain feature (src_tokens) and + targets. + This collator is also used for aligned training task. + """ + + def __init__( + self, + feature_index=0, + label_index=1, + pad_index=1, + eos_index=2, + move_eos_to_beginning=True, + ): + self.feature_index = feature_index + self.label_index = label_index + self.pad_index = pad_index + self.eos_index = eos_index + self.move_eos_to_beginning = move_eos_to_beginning + + def _collate_frames(self, frames): + """Convert a list of 2d frames into a padded 3d tensor + Args: + frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + len_max = max(frame.size(0) for frame in frames) + f_dim = frames[0].size(1) + res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0) + + for i, v in enumerate(frames): + res[i, : v.size(0)] = v + + return res + + def collate(self, samples): + """ + utility function to collate samples into batch for speech recognition. + """ + if len(samples) == 0: + return {} + + # parse samples into torch tensors + parsed_samples = [] + for s in samples: + # skip invalid samples + if s["data"][self.feature_index] is None: + continue + source = s["data"][self.feature_index] + if isinstance(source, (np.ndarray, np.generic)): + source = torch.from_numpy(source) + target = s["data"][self.label_index] + if isinstance(target, (np.ndarray, np.generic)): + target = torch.from_numpy(target).long() + elif isinstance(target, list): + target = torch.LongTensor(target) + + parsed_sample = {"id": s["id"], "source": source, "target": target} + parsed_samples.append(parsed_sample) + samples = parsed_samples + + id = torch.LongTensor([s["id"] for s in samples]) + frames = self._collate_frames([s["source"] for s in samples]) + # sort samples by descending number of frames + frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples]) + frames_lengths, sort_order = frames_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + frames = frames.index_select(0, sort_order) + + target = None + target_lengths = None + prev_output_tokens = None + if samples[0].get("target", None) is not None: + ntokens = sum(len(s["target"]) for s in samples) + target = fairseq_data_utils.collate_tokens( + [s["target"] for s in samples], + self.pad_index, + self.eos_index, + left_pad=False, + move_eos_to_beginning=False, + ) + target = target.index_select(0, sort_order) + target_lengths = torch.LongTensor( + [s["target"].size(0) for s in samples] + ).index_select(0, sort_order) + prev_output_tokens = fairseq_data_utils.collate_tokens( + [s["target"] for s in samples], + self.pad_index, + self.eos_index, + left_pad=False, + move_eos_to_beginning=self.move_eos_to_beginning, + ) + prev_output_tokens = prev_output_tokens.index_select(0, sort_order) + else: + ntokens = sum(len(s["source"]) for s in samples) + + batch = { + "id": id, + "ntokens": ntokens, + "net_input": {"src_tokens": frames, "src_lengths": frames_lengths}, + "target": target, + "target_lengths": target_lengths, + "nsentences": len(samples), + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] = prev_output_tokens + return batch diff --git a/examples/speech_recognition/data/data_utils.py b/examples/speech_recognition/data/data_utils.py new file mode 100644 index 0000000..cc4729e --- /dev/null +++ b/examples/speech_recognition/data/data_utils.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def calc_mean_invstddev(feature): + if len(feature.size()) != 2: + raise ValueError("We expect the input feature to be 2-D tensor") + mean = feature.mean(0) + var = feature.var(0) + # avoid division by ~zero + eps = 1e-8 + if (var < eps).any(): + return mean, 1.0 / (torch.sqrt(var) + eps) + return mean, 1.0 / torch.sqrt(var) + + +def apply_mv_norm(features): + # If there is less than 2 spectrograms, the variance cannot be computed (is NaN) + # and normalization is not possible, so return the item as it is + if features.size(0) < 2: + return features + mean, invstddev = calc_mean_invstddev(features) + res = (features - mean) * invstddev + return res + + +def lengths_to_encoder_padding_mask(lengths, batch_first=False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = 0 for t < lengths[b] and 1 otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) >= lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +def encoder_padding_mask_to_lengths( + encoder_padding_mask, max_lengths, batch_size, device +): + """ + convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor + + Conventionally, encoder output contains a encoder_padding_mask, which is + a 2-D mask in a shape (T, B), whose (t, b) element indicate whether + encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we + need to convert this mask tensor to a 1-D tensor in shape (B, ), where + [b] denotes the valid length of b-th sequence + + Args: + encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None, + indicating all are valid + Return: + seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the + number of valid elements of b-th sequence + + max_lengths: maximum length of all sequence, if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(0) + + batch_size: batch size; if encoder_padding_mask is + not None, max_lengths must equal to encoder_padding_mask.size(1) + + device: which device to put the result on + """ + if encoder_padding_mask is None: + return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device) + + assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match" + assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match" + + return max_lengths - torch.sum(encoder_padding_mask, dim=0) diff --git a/examples/speech_recognition/data/replabels.py b/examples/speech_recognition/data/replabels.py new file mode 100644 index 0000000..d76bda7 --- /dev/null +++ b/examples/speech_recognition/data/replabels.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Replabel transforms for use with wav2letter's ASG criterion. +""" + + +def replabel_symbol(i): + """ + Replabel symbols used in wav2letter, currently just "1", "2", ... + This prevents training with numeral tokens, so this might change in the future + """ + return str(i) + + +def pack_replabels(tokens, dictionary, max_reps): + """ + Pack a token sequence so that repeated symbols are replaced by replabels + """ + if len(tokens) == 0 or max_reps <= 0: + return tokens + + replabel_value_to_idx = [0] * (max_reps + 1) + for i in range(1, max_reps + 1): + replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i)) + + result = [] + prev_token = -1 + num_reps = 0 + for token in tokens: + if token == prev_token and num_reps < max_reps: + num_reps += 1 + else: + if num_reps > 0: + result.append(replabel_value_to_idx[num_reps]) + num_reps = 0 + result.append(token) + prev_token = token + if num_reps > 0: + result.append(replabel_value_to_idx[num_reps]) + return result + + +def unpack_replabels(tokens, dictionary, max_reps): + """ + Unpack a token sequence so that replabels are replaced by repeated symbols + """ + if len(tokens) == 0 or max_reps <= 0: + return tokens + + replabel_idx_to_value = {} + for i in range(1, max_reps + 1): + replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i + + result = [] + prev_token = -1 + for token in tokens: + try: + for _ in range(replabel_idx_to_value[token]): + result.append(prev_token) + prev_token = -1 + except KeyError: + result.append(token) + prev_token = token + return result diff --git a/examples/speech_recognition/datasets/asr_prep_json.py b/examples/speech_recognition/datasets/asr_prep_json.py new file mode 100644 index 0000000..b8db8ff --- /dev/null +++ b/examples/speech_recognition/datasets/asr_prep_json.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import concurrent.futures +import json +import multiprocessing +import os +from collections import namedtuple +from itertools import chain + +import sentencepiece as spm +from fairseq.data import Dictionary + + +MILLISECONDS_TO_SECONDS = 0.001 + + +def process_sample(aud_path, lable, utt_id, sp, tgt_dict): + import torchaudio + + input = {} + output = {} + si, ei = torchaudio.info(aud_path) + input["length_ms"] = int( + si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS + ) + input["path"] = aud_path + + token = " ".join(sp.EncodeAsPieces(lable)) + ids = tgt_dict.encode_line(token, append_eos=False) + output["text"] = lable + output["token"] = token + output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids])) + return {utt_id: {"input": input, "output": output}} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--audio-dirs", + nargs="+", + default=["-"], + required=True, + help="input directories with audio files", + ) + parser.add_argument( + "--labels", + required=True, + help="aggregated input labels with format per line", + type=argparse.FileType("r", encoding="UTF-8"), + ) + parser.add_argument( + "--spm-model", + required=True, + help="sentencepiece model to use for encoding", + type=argparse.FileType("r", encoding="UTF-8"), + ) + parser.add_argument( + "--dictionary", + required=True, + help="file to load fairseq dictionary from", + type=argparse.FileType("r", encoding="UTF-8"), + ) + parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav") + parser.add_argument( + "--output", + required=True, + type=argparse.FileType("w"), + help="path to save json output", + ) + args = parser.parse_args() + + sp = spm.SentencePieceProcessor() + sp.Load(args.spm_model.name) + + tgt_dict = Dictionary.load(args.dictionary) + + labels = {} + for line in args.labels: + (utt_id, label) = line.split(" ", 1) + labels[utt_id] = label + if len(labels) == 0: + raise Exception("No labels found in ", args.labels_path) + + Sample = namedtuple("Sample", "aud_path utt_id") + samples = [] + for path, _, files in chain.from_iterable( + os.walk(path) for path in args.audio_dirs + ): + for f in files: + if f.endswith(args.audio_format): + if len(os.path.splitext(f)) != 2: + raise Exception("Expect file name. Got: ", f) + utt_id = os.path.splitext(f)[0] + if utt_id not in labels: + continue + samples.append(Sample(os.path.join(path, f), utt_id)) + + utts = {} + num_cpu = multiprocessing.cpu_count() + with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor: + future_to_sample = { + executor.submit( + process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict + ): s + for s in samples + } + for future in concurrent.futures.as_completed(future_to_sample): + try: + data = future.result() + except Exception as exc: + print("generated an exception: ", exc) + else: + utts.update(data) + json.dump({"utts": utts}, args.output, indent=4) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_recognition/datasets/prepare-librispeech.sh b/examples/speech_recognition/datasets/prepare-librispeech.sh new file mode 100644 index 0000000..9e9297f --- /dev/null +++ b/examples/speech_recognition/datasets/prepare-librispeech.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env bash +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Prepare librispeech dataset + +base_url=www.openslr.org/resources/12 +train_dir=train_960 + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final" + exit 1 +fi + +download_dir=${1%/} +out_dir=${2%/} + +fairseq_root=~/fairseq-py/ +mkdir -p ${out_dir} +cd ${out_dir} || exit + +nbpe=5000 +bpemode=unigram + +if [ ! -d "$fairseq_root" ]; then + echo "$0: Please set correct fairseq_root" + exit 1 +fi + +echo "Data Download" +for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do + url=$base_url/$part.tar.gz + if ! wget -P $download_dir $url; then + echo "$0: wget failed for $url" + exit 1 + fi + if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then + echo "$0: error un-tarring archive $download_dir/$part.tar.gz" + exit 1 + fi +done + +echo "Merge all train packs into one" +mkdir -p ${download_dir}/LibriSpeech/${train_dir}/ +for part in train-clean-100 train-clean-360 train-other-500; do + mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/ +done +echo "Merge train text" +find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text + +# Use combined dev-clean and dev-other as validation set +find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text +find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text +find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text + + +dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt +encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt +fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt +bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe} +echo "dictionary: ${dict}" +echo "Dictionary preparation" +mkdir -p data/lang_char/ +echo " 3" > ${dict} +echo " 2" >> ${dict} +echo " 1" >> ${dict} +cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt +spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1 +spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded} +cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict} +cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict} +wc -l ${dict} + +echo "Prepare train and test jsons" +for part in train_960 test-other test-clean; do + python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json +done +# fairseq expects to find train.json and valid.json during training +mv train_960.json train.json + +echo "Prepare valid json" +python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json + +cp ${fairseq_dict} ./dict.txt +cp ${bpemodel}.model ./spm.model diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py new file mode 100644 index 0000000..8e91d31 --- /dev/null +++ b/examples/speech_recognition/infer.py @@ -0,0 +1,474 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Run inference for pre-processed data with a trained model. +""" + +import logging +import math +import os +import sys +if '/home/fairseq' in sys.path: + sys.path.remove('/home/fairseq') +sys.path.append('/datablob/users/v-chengw/pyproj/fairseq') + +import editdistance +import numpy as np +import torch +from fairseq import checkpoint_utils, options, progress_bar, tasks, utils, pdb +from fairseq.data.data_utils import post_process +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging.meters import StopwatchMeter, TimeMeter + + +logging.basicConfig() +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def add_asr_eval_argument(parser): + parser.add_argument("--kspmodel", default=None, help="sentence piece model") + parser.add_argument( + "--wfstlm", default=None, help="wfstlm on dictonary output units" + ) + parser.add_argument( + "--rnnt_decoding_type", + default="greedy", + help="wfstlm on dictonary\ +output units", + ) + try: + parser.add_argument( + "--lm-weight", + "--lm_weight", + type=float, + default=0.2, + help="weight for lm while interpolating with neural score", + ) + except: + pass + parser.add_argument( + "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level" + ) + parser.add_argument( + "--w2l-decoder", + choices=["viterbi", "kenlm", "fairseqlm"], + help="use a w2l decoder", + ) + parser.add_argument("--lexicon", help="lexicon for w2l decoder") + parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm") + parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder") + parser.add_argument("--beam-threshold", type=float, default=25.0) + parser.add_argument("--beam-size-token", type=float, default=100) + parser.add_argument("--word-score", type=float, default=1.0) + parser.add_argument("--unk-weight", type=float, default=-math.inf) + parser.add_argument("--sil-weight", type=float, default=0.0) + parser.add_argument( + "--dump-emissions", + type=str, + default=None, + help="if present, dumps emissions into this file and exits", + ) + parser.add_argument( + "--dump-features", + type=str, + default=None, + help="if present, dumps features into this file and exits", + ) + parser.add_argument( + "--load-emissions", + type=str, + default=None, + help="if present, loads emissions from this file", + ) + return parser + + +def check_args(args): + # assert args.path is not None, "--path required for generation!" + # assert args.results_path is not None, "--results_path required for generation!" + assert ( + not args.sampling or args.nbest == args.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + args.replace_unk is None or args.raw_text + ), "--replace-unk requires a raw text dataset (--raw-text)" + + +def get_dataset_itr(args, task, models): + return task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_sentences=args.batch_size, + max_positions=(sys.maxsize, sys.maxsize), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + num_shards=args.num_shards, + shard_id=args.shard_id, + num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, + ).next_epoch_itr(shuffle=False) + + +def process_predictions( + args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id +): + for hypo in hypos[: min(len(hypos), args.nbest)]: + hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu()) + + if "words" in hypo: + hyp_words = " ".join(hypo["words"]) + else: + hyp_words = post_process(hyp_pieces, args.post_process) + + if res_files is not None: + print( + "{} ({}-{})".format(hyp_pieces, speaker, id), + file=res_files["hypo.units"], + ) + print( + "{} ({}-{})".format(hyp_words, speaker, id), + file=res_files["hypo.words"], + ) + + tgt_pieces = tgt_dict.string(target_tokens) + tgt_words = post_process(tgt_pieces, args.post_process) + + if res_files is not None: + print( + "{} ({}-{})".format(tgt_pieces, speaker, id), + file=res_files["ref.units"], + ) + print( + "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"] + ) + # only score top hypothesis + if not args.quiet: + logger.debug("HYPO:" + hyp_words) + logger.debug("TARGET:" + tgt_words) + logger.debug("___________________") + + hyp_words = hyp_words.split() + tgt_words = tgt_words.split() + return editdistance.eval(hyp_words, tgt_words), len(tgt_words) + + +def prepare_result_files(args): + def get_res_file(file_prefix): + if args.num_shards > 1: + file_prefix = f"{args.shard_id}_{file_prefix}" + path = os.path.join( + args.results_path, + "{}-{}-{}.txt".format( + file_prefix, os.path.basename(args.path), args.gen_subset + ), + ) + return open(path, "w", buffering=1) + + if not args.results_path: + return None + + return { + "hypo.words": get_res_file("hypo.word"), + "hypo.units": get_res_file("hypo.units"), + "ref.words": get_res_file("ref.word"), + "ref.units": get_res_file("ref.units"), + } + + +def load_models_and_criterions( + filenames, data_path, arg_overrides=None, task=None, model_state=None +): + models = [] + criterions = [] + + if arg_overrides is None: + arg_overrides = {} + + arg_overrides["wer_args"] = None + arg_overrides["data"] = data_path + + if filenames is None: + assert model_state is not None + filenames = [0] + else: + filenames = filenames.split(":") + + for filename in filenames: + if model_state is None: + if not os.path.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides) + else: + state = model_state + + if "cfg" in state: + cfg = state["cfg"] + else: + cfg = convert_namespace_to_omegaconf(state["args"]) + + if task is None: + if hasattr(cfg.task, 'data'): + cfg.task.data = data_path + task = tasks.setup_task(cfg.task) + + model = task.build_model(cfg.model) + model.load_state_dict(state["model"], strict=True) + models.append(model) + + criterion = task.build_criterion(cfg.criterion) + if "criterion" in state: + criterion.load_state_dict(state["criterion"], strict=True) + criterions.append(criterion) + return models, criterions, task + + +def optimize_models(args, use_cuda, models): + """Optimize ensemble for generation""" + for model in models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if args.fp16: + model.half() + if use_cuda: + model.cuda() + + +class ExistingEmissionsDecoder(object): + def __init__(self, decoder, emissions): + self.decoder = decoder + self.emissions = emissions + + def generate(self, models, sample, **unused): + ids = sample["id"].cpu().numpy() + try: + emissions = np.stack(self.emissions[ids]) + except: + print([x.shape for x in self.emissions[ids]]) + raise Exception("invalid sizes") + emissions = torch.from_numpy(emissions) + return self.decoder.decode(emissions) + + +def main(args, task=None, model_state=None): + check_args(args) + + if args.max_tokens is None and args.batch_size is None: + args.max_tokens = 4000000 + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + + + logger.info("| decoding with criterion {}".format(args.criterion)) + + # Load ensemble + if args.load_emissions: + models, criterions = [], [] + task = tasks.setup_task(args) + else: + logger.info("| loading model(s) from {}".format(args.path)) + models, criterions, task = load_models_and_criterions( + args.path, + data_path=args.data, + arg_overrides=eval(args.model_overrides), # noqa + task=task, + model_state=model_state, + ) + optimize_models(args, use_cuda, models) + + # Load dataset splits + task.load_dataset(args.gen_subset) + + # Set dictionary + tgt_dict = task.target_dictionary + + logger.info( + "| {} {} {} examples".format( + args.data, args.gen_subset, len(task.dataset(args.gen_subset)) + ) + ) + + # hack to pass transitions to W2lDecoder + if args.criterion == "asg_loss": + trans = criterions[0].asg.trans.data + args.asg_transitions = torch.flatten(trans).tolist() + + # Load dataset (possibly sharded) + itr = get_dataset_itr(args, task, models) + + # Initialize generator + gen_timer = StopwatchMeter() + + def build_generator(args): + w2l_decoder = getattr(args, "w2l_decoder", None) + if w2l_decoder == "viterbi": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, task.target_dictionary) + elif w2l_decoder == "kenlm": + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + return W2lKenLMDecoder(args, task.target_dictionary) + elif w2l_decoder == "fairseqlm": + from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder + + return W2lFairseqLMDecoder(args, task.target_dictionary) + else: + print( + "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment" + ) + + # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task + generator = build_generator(args) + + if args.load_emissions: + generator = ExistingEmissionsDecoder( + generator, np.load(args.load_emissions, allow_pickle=True) + ) + logger.info("loaded emissions from " + args.load_emissions) + + num_sentences = 0 + + if args.results_path is not None and not os.path.exists(args.results_path): + os.makedirs(args.results_path) + + max_source_pos = ( + utils.resolve_max_positions( + task.max_positions(), *[model.max_positions() for model in models] + ), + ) + + if max_source_pos is not None: + max_source_pos = max_source_pos[0] + if max_source_pos is not None: + max_source_pos = max_source_pos[0] - 1 + + if args.dump_emissions: + emissions = {} + if args.dump_features: + features = {} + models[0].bert.proj = None + else: + res_files = prepare_result_files(args) + errs_t = 0 + lengths_t = 0 + with progress_bar.build_progress_bar(args, itr) as t: + wps_meter = TimeMeter() + for sample in t: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if "net_input" not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample["target"][:, : args.prefix_size] + + gen_timer.start() + if args.dump_emissions: + with torch.no_grad(): + encoder_out = models[0](**sample["net_input"]) + emm = models[0].get_normalized_probs(encoder_out, log_probs=True) + emm = emm.transpose(0, 1).cpu().numpy() + for i, id in enumerate(sample["id"]): + emissions[id.item()] = emm[i] + continue + elif args.dump_features: + with torch.no_grad(): + encoder_out = models[0](**sample["net_input"]) + feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy() + for i, id in enumerate(sample["id"]): + padding = ( + encoder_out["encoder_padding_mask"][i].cpu().numpy() + if encoder_out["encoder_padding_mask"] is not None + else None + ) + features[id.item()] = (feat[i], padding) + continue + hypos = task.inference_step(generator, models, sample, prefix_tokens) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i, sample_id in enumerate(sample["id"].tolist()): + speaker = None + # id = task.dataset(args.gen_subset).ids[int(sample_id)] + id = sample_id + toks = ( + sample["target"][i, :] + if "target_label" not in sample + else sample["target_label"][i, :] + ) + target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu() + # Process top predictions + errs, length = process_predictions( + args, + hypos[i], + None, + tgt_dict, + target_tokens, + res_files, + speaker, + id, + ) + errs_t += errs + lengths_t += length + + wps_meter.update(num_generated_tokens) + t.log({"wps": round(wps_meter.avg)}) + num_sentences += ( + sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + ) + + wer = None + if args.dump_emissions: + emm_arr = [] + for i in range(len(emissions)): + emm_arr.append(emissions[i]) + np.save(args.dump_emissions, emm_arr) + logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}") + elif args.dump_features: + feat_arr = [] + for i in range(len(features)): + feat_arr.append(features[i]) + np.save(args.dump_features, feat_arr) + logger.info(f"saved {len(features)} emissions to {args.dump_features}") + else: + if lengths_t > 0: + wer = errs_t * 100.0 / lengths_t + logger.info(f"WER: {wer}") + + logger.info( + "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}" + "sentences/s, {:.2f} tokens/s)".format( + num_sentences, + gen_timer.n, + gen_timer.sum, + num_sentences / gen_timer.sum, + 1.0 / gen_timer.avg, + ) + ) + logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam)) + return task, wer + + +def make_parser(): + parser = options.get_generation_parser() + parser = add_asr_eval_argument(parser) + return parser + + +def cli_main(): + parser = make_parser() + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/examples/speech_recognition/models/__init__.py b/examples/speech_recognition/models/__init__.py new file mode 100644 index 0000000..0ad9663 --- /dev/null +++ b/examples/speech_recognition/models/__init__.py @@ -0,0 +1,8 @@ +import importlib +import os + + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + model_name = file[: file.find(".py")] + importlib.import_module("examples.speech_recognition.models." + model_name) diff --git a/examples/speech_recognition/models/vggtransformer.py b/examples/speech_recognition/models/vggtransformer.py new file mode 100644 index 0000000..9797436 --- /dev/null +++ b/examples/speech_recognition/models/vggtransformer.py @@ -0,0 +1,1019 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import math +from collections.abc import Iterable + +import torch +import torch.nn as nn +from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqEncoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import ( + LinearizedConvolution, + TransformerDecoderLayer, + TransformerEncoderLayer, + VGGBlock, +) + + +@register_model("asr_vggtransformer") +class VGGTransformerModel(FairseqEncoderDecoderModel): + """ + Transformers with convolutional context for ASR + https://arxiv.org/abs/1904.11660 + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--vggblock-enc-config", + type=str, + metavar="EXPR", + help=""" + an array of tuples each containing the configuration of one vggblock: + [(out_channels, + conv_kernel_size, + pooling_kernel_size, + num_conv_layers, + use_layer_norm), ...]) + """, + ) + parser.add_argument( + "--transformer-enc-config", + type=str, + metavar="EXPR", + help="""" + a tuple containing the configuration of the encoder transformer layers + configurations: + [(input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout), ...]') + """, + ) + parser.add_argument( + "--enc-output-dim", + type=int, + metavar="N", + help=""" + encoder output dimension, can be None. If specified, projecting the + transformer output to the specified dimension""", + ) + parser.add_argument( + "--in-channels", + type=int, + metavar="N", + help="number of encoder input channels", + ) + parser.add_argument( + "--tgt-embed-dim", + type=int, + metavar="N", + help="embedding dimension of the decoder target tokens", + ) + parser.add_argument( + "--transformer-dec-config", + type=str, + metavar="EXPR", + help=""" + a tuple containing the configuration of the decoder transformer layers + configurations: + [(input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout), ...] + """, + ) + parser.add_argument( + "--conv-dec-config", + type=str, + metavar="EXPR", + help=""" + an array of tuples for the decoder 1-D convolution config + [(out_channels, conv_kernel_size, use_layer_norm), ...]""", + ) + + @classmethod + def build_encoder(cls, args, task): + return VGGTransformerEncoder( + input_feat_per_channel=args.input_feat_per_channel, + vggblock_config=eval(args.vggblock_enc_config), + transformer_config=eval(args.transformer_enc_config), + encoder_output_dim=args.enc_output_dim, + in_channels=args.in_channels, + ) + + @classmethod + def build_decoder(cls, args, task): + return TransformerDecoder( + dictionary=task.target_dictionary, + embed_dim=args.tgt_embed_dim, + transformer_config=eval(args.transformer_dec_config), + conv_config=eval(args.conv_dec_config), + encoder_output_dim=args.enc_output_dim, + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted + # (in case there are any new ones) + base_architecture(args) + + encoder = cls.build_encoder(args, task) + decoder = cls.build_decoder(args, task) + return cls(encoder, decoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + lprobs.batch_first = True + return lprobs + + +DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2 +DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2 +# 256: embedding dimension +# 4: number of heads +# 1024: FFN +# True: apply layerNorm before (dropout + resiaul) instead of after +# 0.2 (dropout): dropout after MultiheadAttention and second FC +# 0.2 (attention_dropout): dropout in MultiheadAttention +# 0.2 (relu_dropout): dropout after ReLu +DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2 +DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2 + + +# TODO: repace transformer encoder config from one liner +# to explicit args to get rid of this transformation +def prepare_transformer_encoder_params( + input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout, +): + args = argparse.Namespace() + args.encoder_embed_dim = input_dim + args.encoder_attention_heads = num_heads + args.attention_dropout = attention_dropout + args.dropout = dropout + args.activation_dropout = relu_dropout + args.encoder_normalize_before = normalize_before + args.encoder_ffn_embed_dim = ffn_dim + return args + + +def prepare_transformer_decoder_params( + input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout, +): + args = argparse.Namespace() + args.decoder_embed_dim = input_dim + args.decoder_attention_heads = num_heads + args.attention_dropout = attention_dropout + args.dropout = dropout + args.activation_dropout = relu_dropout + args.decoder_normalize_before = normalize_before + args.decoder_ffn_embed_dim = ffn_dim + return args + + +class VGGTransformerEncoder(FairseqEncoder): + """VGG + Transformer encoder""" + + def __init__( + self, + input_feat_per_channel, + vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG, + transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, + encoder_output_dim=512, + in_channels=1, + transformer_context=None, + transformer_sampling=None, + ): + """constructor for VGGTransformerEncoder + + Args: + - input_feat_per_channel: feature dim (not including stacked, + just base feature) + - in_channel: # input channels (e.g., if stack 8 feature vector + together, this is 8) + - vggblock_config: configuration of vggblock, see comments on + DEFAULT_ENC_VGGBLOCK_CONFIG + - transformer_config: configuration of transformer layer, see comments + on DEFAULT_ENC_TRANSFORMER_CONFIG + - encoder_output_dim: final transformer output embedding dimension + - transformer_context: (left, right) if set, self-attention will be focused + on (t-left, t+right) + - transformer_sampling: an iterable of int, must match with + len(transformer_config), transformer_sampling[i] indicates sampling + factor for i-th transformer layer, after multihead att and feedfoward + part + """ + super().__init__(None) + + self.num_vggblocks = 0 + if vggblock_config is not None: + if not isinstance(vggblock_config, Iterable): + raise ValueError("vggblock_config is not iterable") + self.num_vggblocks = len(vggblock_config) + + self.conv_layers = nn.ModuleList() + self.in_channels = in_channels + self.input_dim = input_feat_per_channel + self.pooling_kernel_sizes = [] + + if vggblock_config is not None: + for _, config in enumerate(vggblock_config): + ( + out_channels, + conv_kernel_size, + pooling_kernel_size, + num_conv_layers, + layer_norm, + ) = config + self.conv_layers.append( + VGGBlock( + in_channels, + out_channels, + conv_kernel_size, + pooling_kernel_size, + num_conv_layers, + input_dim=input_feat_per_channel, + layer_norm=layer_norm, + ) + ) + self.pooling_kernel_sizes.append(pooling_kernel_size) + in_channels = out_channels + input_feat_per_channel = self.conv_layers[-1].output_dim + + transformer_input_dim = self.infer_conv_output_dim( + self.in_channels, self.input_dim + ) + # transformer_input_dim is the output dimension of VGG part + + self.validate_transformer_config(transformer_config) + self.transformer_context = self.parse_transformer_context(transformer_context) + self.transformer_sampling = self.parse_transformer_sampling( + transformer_sampling, len(transformer_config) + ) + + self.transformer_layers = nn.ModuleList() + + if transformer_input_dim != transformer_config[0][0]: + self.transformer_layers.append( + Linear(transformer_input_dim, transformer_config[0][0]) + ) + self.transformer_layers.append( + TransformerEncoderLayer( + prepare_transformer_encoder_params(*transformer_config[0]) + ) + ) + + for i in range(1, len(transformer_config)): + if transformer_config[i - 1][0] != transformer_config[i][0]: + self.transformer_layers.append( + Linear(transformer_config[i - 1][0], transformer_config[i][0]) + ) + self.transformer_layers.append( + TransformerEncoderLayer( + prepare_transformer_encoder_params(*transformer_config[i]) + ) + ) + + self.encoder_output_dim = encoder_output_dim + self.transformer_layers.extend( + [ + Linear(transformer_config[-1][0], encoder_output_dim), + LayerNorm(encoder_output_dim), + ] + ) + + def forward(self, src_tokens, src_lengths, **kwargs): + """ + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + bsz, max_seq_len, _ = src_tokens.size() + x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + x = x.transpose(1, 2).contiguous() + # (B, C, T, feat) + + for layer_idx in range(len(self.conv_layers)): + x = self.conv_layers[layer_idx](x) + + bsz, _, output_seq_len, _ = x.size() + + # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat) + x = x.transpose(1, 2).transpose(0, 1) + x = x.contiguous().view(output_seq_len, bsz, -1) + + input_lengths = src_lengths.clone() + for s in self.pooling_kernel_sizes: + input_lengths = (input_lengths.float() / s).ceil().long() + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + input_lengths, batch_first=True + ) + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) + attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor) + + transformer_layer_idx = 0 + + for layer_idx in range(len(self.transformer_layers)): + + if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer): + x = self.transformer_layers[layer_idx]( + x, encoder_padding_mask, attn_mask + ) + + if self.transformer_sampling[transformer_layer_idx] != 1: + sampling_factor = self.transformer_sampling[transformer_layer_idx] + x, encoder_padding_mask, attn_mask = self.slice( + x, encoder_padding_mask, attn_mask, sampling_factor + ) + + transformer_layer_idx += 1 + + else: + x = self.transformer_layers[layer_idx](x) + + # encoder_padding_maks is a (T x B) tensor, its [t, b] elements indicate + # whether encoder_output[t, b] is valid or not (valid=0, invalid=1) + + return { + "encoder_out": x, # (T, B, C) + "encoder_padding_mask": encoder_padding_mask.t() + if encoder_padding_mask is not None + else None, + # (B, T) --> (T, B) + } + + def infer_conv_output_dim(self, in_channels, input_dim): + sample_seq_len = 200 + sample_bsz = 10 + x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) + for i, _ in enumerate(self.conv_layers): + x = self.conv_layers[i](x) + x = x.transpose(1, 2) + mb, seq = x.size()[:2] + return x.contiguous().view(mb, seq, -1).size(-1) + + def validate_transformer_config(self, transformer_config): + for config in transformer_config: + input_dim, num_heads = config[:2] + if input_dim % num_heads != 0: + msg = ( + "ERROR in transformer config {}: ".format(config) + + "input dimension {} ".format(input_dim) + + "not dividable by number of heads {}".format(num_heads) + ) + raise ValueError(msg) + + def parse_transformer_context(self, transformer_context): + """ + transformer_context can be the following: + - None; indicates no context is used, i.e., + transformer can access full context + - a tuple/list of two int; indicates left and right context, + any number <0 indicates infinite context + * e.g., (5, 6) indicates that for query at x_t, transformer can + access [t-5, t+6] (inclusive) + * e.g., (-1, 6) indicates that for query at x_t, transformer can + access [0, t+6] (inclusive) + """ + if transformer_context is None: + return None + + if not isinstance(transformer_context, Iterable): + raise ValueError("transformer context must be Iterable if it is not None") + + if len(transformer_context) != 2: + raise ValueError("transformer context must have length 2") + + left_context = transformer_context[0] + if left_context < 0: + left_context = None + + right_context = transformer_context[1] + if right_context < 0: + right_context = None + + if left_context is None and right_context is None: + return None + + return (left_context, right_context) + + def parse_transformer_sampling(self, transformer_sampling, num_layers): + """ + parsing transformer sampling configuration + + Args: + - transformer_sampling, accepted input: + * None, indicating no sampling + * an Iterable with int (>0) as element + - num_layers, expected number of transformer layers, must match with + the length of transformer_sampling if it is not None + + Returns: + - A tuple with length num_layers + """ + if transformer_sampling is None: + return (1,) * num_layers + + if not isinstance(transformer_sampling, Iterable): + raise ValueError( + "transformer_sampling must be an iterable if it is not None" + ) + + if len(transformer_sampling) != num_layers: + raise ValueError( + "transformer_sampling {} does not match with the number " + "of layers {}".format(transformer_sampling, num_layers) + ) + + for layer, value in enumerate(transformer_sampling): + if not isinstance(value, int): + raise ValueError("Invalid value in transformer_sampling: ") + if value < 1: + raise ValueError( + "{} layer's subsampling is {}.".format(layer, value) + + " This is not allowed! " + ) + return transformer_sampling + + def slice(self, embedding, padding_mask, attn_mask, sampling_factor): + """ + embedding is a (T, B, D) tensor + padding_mask is a (B, T) tensor or None + attn_mask is a (T, T) tensor or None + """ + embedding = embedding[::sampling_factor, :, :] + if padding_mask is not None: + padding_mask = padding_mask[:, ::sampling_factor] + if attn_mask is not None: + attn_mask = attn_mask[::sampling_factor, ::sampling_factor] + + return embedding, padding_mask, attn_mask + + def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1): + """ + create attention mask according to sequence lengths and transformer + context + + Args: + - input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is + the length of b-th sequence + - subsampling_factor: int + * Note that the left_context and right_context is specified in + the input frame-level while input to transformer may already + go through subsampling (e.g., the use of striding in vggblock) + we use subsampling_factor to scale the left/right context + + Return: + - a (T, T) binary tensor or None, where T is max(input_lengths) + * if self.transformer_context is None, None + * if left_context is None, + * attn_mask[t, t + right_context + 1:] = 1 + * others = 0 + * if right_context is None, + * attn_mask[t, 0:t - left_context] = 1 + * others = 0 + * elsif + * attn_mask[t, t - left_context: t + right_context + 1] = 0 + * others = 1 + """ + if self.transformer_context is None: + return None + + maxT = torch.max(input_lengths).item() + attn_mask = torch.zeros(maxT, maxT) + + left_context = self.transformer_context[0] + right_context = self.transformer_context[1] + if left_context is not None: + left_context = math.ceil(self.transformer_context[0] / subsampling_factor) + if right_context is not None: + right_context = math.ceil(self.transformer_context[1] / subsampling_factor) + + for t in range(maxT): + if left_context is not None: + st = 0 + en = max(st, t - left_context) + attn_mask[t, st:en] = 1 + if right_context is not None: + st = t + right_context + 1 + st = min(st, maxT - 1) + attn_mask[t, st:] = 1 + + return attn_mask.to(input_lengths.device) + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs. + Default: ``False`` + left_pad (bool, optional): whether the input is left-padded. Default: + ``False`` + """ + + def __init__( + self, + dictionary, + embed_dim=512, + transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, + conv_config=DEFAULT_DEC_CONV_CONFIG, + encoder_output_dim=512, + ): + + super().__init__(dictionary) + vocab_size = len(dictionary) + self.padding_idx = dictionary.pad() + self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx) + + self.conv_layers = nn.ModuleList() + for i in range(len(conv_config)): + out_channels, kernel_size, layer_norm = conv_config[i] + if i == 0: + conv_layer = LinearizedConv1d( + embed_dim, out_channels, kernel_size, padding=kernel_size - 1 + ) + else: + conv_layer = LinearizedConv1d( + conv_config[i - 1][0], + out_channels, + kernel_size, + padding=kernel_size - 1, + ) + self.conv_layers.append(conv_layer) + if layer_norm: + self.conv_layers.append(nn.LayerNorm(out_channels)) + self.conv_layers.append(nn.ReLU()) + + self.layers = nn.ModuleList() + if conv_config[-1][0] != transformer_config[0][0]: + self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0])) + self.layers.append( + TransformerDecoderLayer( + prepare_transformer_decoder_params(*transformer_config[0]) + ) + ) + + for i in range(1, len(transformer_config)): + if transformer_config[i - 1][0] != transformer_config[i][0]: + self.layers.append( + Linear(transformer_config[i - 1][0], transformer_config[i][0]) + ) + self.layers.append( + TransformerDecoderLayer( + prepare_transformer_decoder_params(*transformer_config[i]) + ) + ) + self.fc_out = Linear(transformer_config[-1][0], vocab_size) + + def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for input feeding/teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + Returns: + tuple: + - the last decoder layer's output of shape `(batch, tgt_len, + vocab)` + - the last decoder layer's attention weights of shape `(batch, + tgt_len, src_len)` + """ + target_padding_mask = ( + (prev_output_tokens == self.padding_idx).to(prev_output_tokens.device) + if incremental_state is None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + + # embed tokens + x = self.embed_tokens(prev_output_tokens) + + # B x T x C -> T x B x C + x = self._transpose_if_training(x, incremental_state) + + for layer in self.conv_layers: + if isinstance(layer, LinearizedConvolution): + x = layer(x, incremental_state) + else: + x = layer(x) + + # B x T x C -> T x B x C + x = self._transpose_if_inference(x, incremental_state) + + # decoder layers + for layer in self.layers: + if isinstance(layer, TransformerDecoderLayer): + x, *_ = layer( + x, + (encoder_out["encoder_out"] if encoder_out is not None else None), + ( + encoder_out["encoder_padding_mask"].t() + if encoder_out["encoder_padding_mask"] is not None + else None + ), + incremental_state, + self_attn_mask=( + self.buffered_future_mask(x) + if incremental_state is None + else None + ), + self_attn_padding_mask=( + target_padding_mask if incremental_state is None else None + ), + ) + else: + x = layer(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + x = self.fc_out(x) + + return x, None + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu( + utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def _transpose_if_training(self, x, incremental_state): + if incremental_state is None: + x = x.transpose(0, 1) + return x + + def _transpose_if_inference(self, x, incremental_state): + if incremental_state: + x = x.transpose(0, 1) + return x + + +@register_model("asr_vggtransformer_encoder") +class VGGTransformerEncoderModel(FairseqEncoderModel): + def __init__(self, encoder): + super().__init__(encoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--vggblock-enc-config", + type=str, + metavar="EXPR", + help=""" + an array of tuples each containing the configuration of one vggblock + [(out_channels, conv_kernel_size, pooling_kernel_size,num_conv_layers), ...] + """, + ) + parser.add_argument( + "--transformer-enc-config", + type=str, + metavar="EXPR", + help=""" + a tuple containing the configuration of the Transformer layers + configurations: + [(input_dim, + num_heads, + ffn_dim, + normalize_before, + dropout, + attention_dropout, + relu_dropout), ]""", + ) + parser.add_argument( + "--enc-output-dim", + type=int, + metavar="N", + help="encoder output dimension, projecting the LSTM output", + ) + parser.add_argument( + "--in-channels", + type=int, + metavar="N", + help="number of encoder input channels", + ) + parser.add_argument( + "--transformer-context", + type=str, + metavar="EXPR", + help=""" + either None or a tuple of two ints, indicating left/right context a + transformer can have access to""", + ) + parser.add_argument( + "--transformer-sampling", + type=str, + metavar="EXPR", + help=""" + either None or a tuple of ints, indicating sampling factor in each layer""", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + base_architecture_enconly(args) + encoder = VGGTransformerEncoderOnly( + vocab_size=len(task.target_dictionary), + input_feat_per_channel=args.input_feat_per_channel, + vggblock_config=eval(args.vggblock_enc_config), + transformer_config=eval(args.transformer_enc_config), + encoder_output_dim=args.enc_output_dim, + in_channels=args.in_channels, + transformer_context=eval(args.transformer_context), + transformer_sampling=eval(args.transformer_sampling), + ) + return cls(encoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + # net_output['encoder_out'] is a (T, B, D) tensor + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + # lprobs is a (T, B, D) tensor + # we need to transoose to get (B, T, D) tensor + lprobs = lprobs.transpose(0, 1).contiguous() + lprobs.batch_first = True + return lprobs + + +class VGGTransformerEncoderOnly(VGGTransformerEncoder): + def __init__( + self, + vocab_size, + input_feat_per_channel, + vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG, + transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG, + encoder_output_dim=512, + in_channels=1, + transformer_context=None, + transformer_sampling=None, + ): + super().__init__( + input_feat_per_channel=input_feat_per_channel, + vggblock_config=vggblock_config, + transformer_config=transformer_config, + encoder_output_dim=encoder_output_dim, + in_channels=in_channels, + transformer_context=transformer_context, + transformer_sampling=transformer_sampling, + ) + self.fc_out = Linear(self.encoder_output_dim, vocab_size) + + def forward(self, src_tokens, src_lengths, **kwargs): + """ + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + + enc_out = super().forward(src_tokens, src_lengths) + x = self.fc_out(enc_out["encoder_out"]) + # x = F.log_softmax(x, dim=-1) + # Note: no need this line, because model.get_normalized_prob will call + # log_softmax + return { + "encoder_out": x, # (T, B, C) + "encoder_padding_mask": enc_out["encoder_padding_mask"], # (T, B) + } + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return (1e6, 1e6) # an arbitrary large number + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + # nn.init.uniform_(m.weight, -0.1, 0.1) + # nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True, dropout=0): + """Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + # m.weight.data.uniform_(-0.1, 0.1) + # if bias: + # m.bias.data.uniform_(-0.1, 0.1) + return m + + +def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): + """Weight-normalized Conv1d layer optimized for decoding""" + m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) + std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) + nn.init.normal_(m.weight, mean=0, std=std) + nn.init.constant_(m.bias, 0) + return nn.utils.weight_norm(m, dim=2) + + +def LayerNorm(embedding_dim): + m = nn.LayerNorm(embedding_dim) + return m + + +# seq2seq models +def base_architecture(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG + ) + args.transformer_enc_config = getattr( + args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 512) + args.in_channels = getattr(args, "in_channels", 1) + args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128) + args.transformer_dec_config = getattr( + args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG + ) + args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG) + args.transformer_context = getattr(args, "transformer_context", "None") + + +@register_model_architecture("asr_vggtransformer", "vggtransformer_1") +def vggtransformer_1(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" + ) + args.transformer_enc_config = getattr( + args, + "transformer_enc_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14", + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 1024) + args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128) + args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") + args.transformer_dec_config = getattr( + args, + "transformer_dec_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4", + ) + + +@register_model_architecture("asr_vggtransformer", "vggtransformer_2") +def vggtransformer_2(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" + ) + args.transformer_enc_config = getattr( + args, + "transformer_enc_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16", + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 1024) + args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512) + args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") + args.transformer_dec_config = getattr( + args, + "transformer_dec_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6", + ) + + +@register_model_architecture("asr_vggtransformer", "vggtransformer_base") +def vggtransformer_base(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" + ) + args.transformer_enc_config = getattr( + args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12" + ) + + args.enc_output_dim = getattr(args, "enc_output_dim", 512) + args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512) + args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4") + args.transformer_dec_config = getattr( + args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6" + ) + # Size estimations: + # Encoder: + # - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3 = 258K + # Transformer: + # - input dimension adapter: 2560 x 512 -> 1.31M + # - transformer_layers (x12) --> 37.74M + # * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M + # * FFN weight: 512*2048*2 = 2.097M + # - output dimension adapter: 512 x 512 -> 0.26 M + # Decoder: + # - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3 + # - transformer_layer: (x6) --> 25.16M + # * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M + # * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M + # * FFN: 512*2048*2 = 2.097M + # Final FC: + # - FC: 512*5000 = 256K (assuming vocab size 5K) + # In total: + # ~65 M + + +# CTC models +def base_architecture_enconly(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(32, 3, 2, 2, True)] * 2" + ) + args.transformer_enc_config = getattr( + args, "transformer_enc_config", "((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2" + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 512) + args.in_channels = getattr(args, "in_channels", 1) + args.transformer_context = getattr(args, "transformer_context", "None") + args.transformer_sampling = getattr(args, "transformer_sampling", "None") + + +@register_model_architecture("asr_vggtransformer_encoder", "vggtransformer_enc_1") +def vggtransformer_enc_1(args): + # vggtransformer_1 is the same as vggtransformer_enc_big, except the number + # of layers is increased to 16 + # keep it here for backward compatiablity purpose + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.vggblock_enc_config = getattr( + args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]" + ) + args.transformer_enc_config = getattr( + args, + "transformer_enc_config", + "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16", + ) + args.enc_output_dim = getattr(args, "enc_output_dim", 1024) diff --git a/examples/speech_recognition/models/w2l_conv_glu_enc.py b/examples/speech_recognition/models/w2l_conv_glu_enc.py new file mode 100644 index 0000000..655a9b0 --- /dev/null +++ b/examples/speech_recognition/models/w2l_conv_glu_enc.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.modules.fairseq_dropout import FairseqDropout + + +default_conv_enc_config = """[ + (400, 13, 170, 0.2), + (440, 14, 0, 0.214), + (484, 15, 0, 0.22898), + (532, 16, 0, 0.2450086), + (584, 17, 0, 0.262159202), + (642, 18, 0, 0.28051034614), + (706, 19, 0, 0.30014607037), + (776, 20, 0, 0.321156295296), + (852, 21, 0, 0.343637235966), + (936, 22, 0, 0.367691842484), + (1028, 23, 0, 0.393430271458), + (1130, 24, 0, 0.42097039046), + (1242, 25, 0, 0.450438317792), + (1366, 26, 0, 0.481969000038), + (1502, 27, 0, 0.51570683004), + (1652, 28, 0, 0.551806308143), + (1816, 29, 0, 0.590432749713), +]""" + + +@register_model("asr_w2l_conv_glu_encoder") +class W2lConvGluEncoderModel(FairseqEncoderModel): + def __init__(self, encoder): + super().__init__(encoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--in-channels", + type=int, + metavar="N", + help="number of encoder input channels", + ) + parser.add_argument( + "--conv-enc-config", + type=str, + metavar="EXPR", + help=""" + an array of tuples each containing the configuration of one conv layer + [(out_channels, kernel_size, padding, dropout), ...] + """, + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config) + encoder = W2lConvGluEncoder( + vocab_size=len(task.target_dictionary), + input_feat_per_channel=args.input_feat_per_channel, + in_channels=args.in_channels, + conv_enc_config=eval(conv_enc_config), + ) + return cls(encoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + lprobs.batch_first = False + return lprobs + + +class W2lConvGluEncoder(FairseqEncoder): + def __init__( + self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config + ): + super().__init__(None) + + self.input_dim = input_feat_per_channel + if in_channels != 1: + raise ValueError("only 1 input channel is currently supported") + + self.conv_layers = nn.ModuleList() + self.linear_layers = nn.ModuleList() + self.dropouts = [] + cur_channels = input_feat_per_channel + + for out_channels, kernel_size, padding, dropout in conv_enc_config: + layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding) + layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init + self.conv_layers.append(nn.utils.weight_norm(layer)) + self.dropouts.append( + FairseqDropout(dropout, module_name=self.__class__.__name__) + ) + if out_channels % 2 != 0: + raise ValueError("odd # of out_channels is incompatible with GLU") + cur_channels = out_channels // 2 # halved by GLU + + for out_channels in [2 * cur_channels, vocab_size]: + layer = nn.Linear(cur_channels, out_channels) + layer.weight.data.mul_(math.sqrt(3)) + self.linear_layers.append(nn.utils.weight_norm(layer)) + cur_channels = out_channels // 2 + + def forward(self, src_tokens, src_lengths, **kwargs): + + """ + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + B, T, _ = src_tokens.size() + x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1 + + for layer_idx in range(len(self.conv_layers)): + x = self.conv_layers[layer_idx](x) + x = F.glu(x, dim=1) + x = self.dropouts[layer_idx](x) + + x = x.transpose(1, 2).contiguous() # (B, T, 908) + x = self.linear_layers[0](x) + x = F.glu(x, dim=2) + x = self.dropouts[-1](x) + x = self.linear_layers[1](x) + + assert x.size(0) == B + assert x.size(1) == T + + encoder_out = x.transpose(0, 1) # (T, B, vocab_size) + + # need to debug this -- find a simpler/elegant way in pytorch APIs + encoder_padding_mask = ( + torch.arange(T).view(1, T).expand(B, -1).to(x.device) + >= src_lengths.view(B, 1).expand(-1, T) + ).t() # (B x T) -> (T x B) + + return { + "encoder_out": encoder_out, # (T, B, vocab_size) + "encoder_padding_mask": encoder_padding_mask, # (T, B) + } + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return (1e6, 1e6) # an arbitrary large number + + +@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc") +def w2l_conv_glu_enc(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.in_channels = getattr(args, "in_channels", 1) + args.conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config) diff --git a/examples/speech_recognition/tasks/__init__.py b/examples/speech_recognition/tasks/__init__.py new file mode 100644 index 0000000..ffa5f3b --- /dev/null +++ b/examples/speech_recognition/tasks/__init__.py @@ -0,0 +1,8 @@ +import importlib +import os + + +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + task_name = file[: file.find(".py")] + importlib.import_module("examples.speech_recognition.tasks." + task_name) diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py new file mode 100644 index 0000000..d9f011d --- /dev/null +++ b/examples/speech_recognition/tasks/speech_recognition.py @@ -0,0 +1,157 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import re +import sys + +import torch +from examples.speech_recognition.data import AsrDataset +from examples.speech_recognition.data.replabels import replabel_symbol +from fairseq.data import Dictionary +from fairseq.tasks import LegacyFairseqTask, register_task + + +def get_asr_dataset_from_json(data_json_path, tgt_dict): + """ + Parse data json and create dataset. + See scripts/asr_prep_json.py which pack json from raw files + + Json example: + { + "utts": { + "4771-29403-0025": { + "input": { + "length_ms": 170, + "path": "/tmp/file1.flac" + }, + "output": { + "text": "HELLO \n", + "token": "HE LLO", + "tokenid": "4815, 861" + } + }, + "1564-142299-0096": { + ... + } + } + """ + if not os.path.isfile(data_json_path): + raise FileNotFoundError("Dataset not found: {}".format(data_json_path)) + with open(data_json_path, "rb") as f: + data_samples = json.load(f)["utts"] + assert len(data_samples) != 0 + sorted_samples = sorted( + data_samples.items(), + key=lambda sample: int(sample[1]["input"]["length_ms"]), + reverse=True, + ) + aud_paths = [s[1]["input"]["path"] for s in sorted_samples] + ids = [s[0] for s in sorted_samples] + speakers = [] + for s in sorted_samples: + m = re.search("(.+?)-(.+?)-(.+?)", s[0]) + speakers.append(m.group(1) + "_" + m.group(2)) + frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples] + tgt = [ + [int(i) for i in s[1]["output"]["tokenid"].split(", ")] + for s in sorted_samples + ] + # append eos + tgt = [[*t, tgt_dict.eos()] for t in tgt] + return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers) + + +@register_task("speech_recognition") +class SpeechRecognitionTask(LegacyFairseqTask): + """ + Task for training speech recognition model. + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", help="path to data directory") + parser.add_argument( + "--silence-token", default="\u2581", help="token for silence (used by w2l)" + ) + parser.add_argument( + "--max-source-positions", + default=sys.maxsize, + type=int, + metavar="N", + help="max number of frames in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + + def __init__(self, args, tgt_dict): + super().__init__(args) + self.tgt_dict = tgt_dict + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries).""" + dict_path = os.path.join(args.data, "dict.txt") + if not os.path.isfile(dict_path): + raise FileNotFoundError("Dict not found: {}".format(dict_path)) + tgt_dict = Dictionary.load(dict_path) + + if args.criterion == "ctc_loss": + tgt_dict.add_symbol("") + elif args.criterion == "asg_loss": + for i in range(1, args.max_replabel + 1): + tgt_dict.add_symbol(replabel_symbol(i)) + + print("| dictionary: {} types".format(len(tgt_dict))) + return cls(args, tgt_dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + data_json_path = os.path.join(self.args.data, "{}.json".format(split)) + self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict) + + def build_generator(self, models, args, **unused): + w2l_decoder = getattr(args, "w2l_decoder", None) + if w2l_decoder == "viterbi": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, self.target_dictionary) + elif w2l_decoder == "kenlm": + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + return W2lKenLMDecoder(args, self.target_dictionary) + elif w2l_decoder == "fairseqlm": + from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder + + return W2lFairseqLMDecoder(args, self.target_dictionary) + else: + return super().build_generator(models, args) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.tgt_dict + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None + + def max_positions(self): + """Return the max speech and sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) diff --git a/examples/speech_recognition/utils/wer_utils.py b/examples/speech_recognition/utils/wer_utils.py new file mode 100644 index 0000000..cf6f3d0 --- /dev/null +++ b/examples/speech_recognition/utils/wer_utils.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import re +from collections import deque +from enum import Enum + +import numpy as np + + +""" + Utility modules for computation of Word Error Rate, + Alignments, as well as more granular metrics like + deletion, insersion and substitutions. +""" + + +class Code(Enum): + match = 1 + substitution = 2 + insertion = 3 + deletion = 4 + + +class Token(object): + def __init__(self, lbl="", st=np.nan, en=np.nan): + if np.isnan(st): + self.label, self.start, self.end = "", 0.0, 0.0 + else: + self.label, self.start, self.end = lbl, st, en + + +class AlignmentResult(object): + def __init__(self, refs, hyps, codes, score): + self.refs = refs # std::deque + self.hyps = hyps # std::deque + self.codes = codes # std::deque + self.score = score # float + + +def coordinate_to_offset(row, col, ncols): + return int(row * ncols + col) + + +def offset_to_row(offset, ncols): + return int(offset / ncols) + + +def offset_to_col(offset, ncols): + return int(offset % ncols) + + +def trimWhitespace(str): + return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str))) + + +def str2toks(str): + pieces = trimWhitespace(str).split(" ") + toks = [] + for p in pieces: + toks.append(Token(p, 0.0, 0.0)) + return toks + + +class EditDistance(object): + def __init__(self, time_mediated): + self.time_mediated_ = time_mediated + self.scores_ = np.nan # Eigen::Matrix + self.backtraces_ = ( + np.nan + ) # Eigen::Matrix backtraces_; + self.confusion_pairs_ = {} + + def cost(self, ref, hyp, code): + if self.time_mediated_: + if code == Code.match: + return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + elif code == Code.insertion: + return hyp.end - hyp.start + elif code == Code.deletion: + return ref.end - ref.start + else: # substitution + return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1 + else: + if code == Code.match: + return 0 + elif code == Code.insertion or code == Code.deletion: + return 3 + else: # substitution + return 4 + + def get_result(self, refs, hyps): + res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan) + + num_rows, num_cols = self.scores_.shape + res.score = self.scores_[num_rows - 1, num_cols - 1] + + curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols) + + while curr_offset != 0: + curr_row = offset_to_row(curr_offset, num_cols) + curr_col = offset_to_col(curr_offset, num_cols) + + prev_offset = self.backtraces_[curr_row, curr_col] + + prev_row = offset_to_row(prev_offset, num_cols) + prev_col = offset_to_col(prev_offset, num_cols) + + res.refs.appendleft(curr_row - 1) # Note: this was .push_front() in C++ + res.hyps.appendleft(curr_col - 1) + if curr_row - 1 == prev_row and curr_col == prev_col: + res.codes.appendleft(Code.deletion) + elif curr_row == prev_row and curr_col - 1 == prev_col: + res.codes.appendleft(Code.insertion) + else: + # assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col) + ref_str = refs[res.refs[0]].label + hyp_str = hyps[res.hyps[0]].label + + if ref_str == hyp_str: + res.codes.appendleft(Code.match) + else: + res.codes.appendleft(Code.substitution) + + confusion_pair = "%s -> %s" % (ref_str, hyp_str) + if confusion_pair not in self.confusion_pairs_: + self.confusion_pairs_[confusion_pair] = 1 + else: + self.confusion_pairs_[confusion_pair] += 1 + + curr_offset = prev_offset + + return res + + def align(self, refs, hyps): + if len(refs) == 0 and len(hyps) == 0: + return np.nan + + # NOTE: we're not resetting the values in these matrices because every value + # will be overridden in the loop below. If this assumption doesn't hold, + # be sure to set all entries in self.scores_ and self.backtraces_ to 0. + self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1)) + self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1)) + + num_rows, num_cols = self.scores_.shape + + for i in range(num_rows): + for j in range(num_cols): + if i == 0 and j == 0: + self.scores_[i, j] = 0.0 + self.backtraces_[i, j] = 0 + continue + + if i == 0: + self.scores_[i, j] = self.scores_[i, j - 1] + self.cost( + None, hyps[j - 1], Code.insertion + ) + self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols) + continue + + if j == 0: + self.scores_[i, j] = self.scores_[i - 1, j] + self.cost( + refs[i - 1], None, Code.deletion + ) + self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols) + continue + + # Below here both i and j are greater than 0 + ref = refs[i - 1] + hyp = hyps[j - 1] + best_score = self.scores_[i - 1, j - 1] + ( + self.cost(ref, hyp, Code.match) + if (ref.label == hyp.label) + else self.cost(ref, hyp, Code.substitution) + ) + + prev_row = i - 1 + prev_col = j - 1 + ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion) + if ins < best_score: + best_score = ins + prev_row = i + prev_col = j - 1 + + delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion) + if delt < best_score: + best_score = delt + prev_row = i - 1 + prev_col = j + + self.scores_[i, j] = best_score + self.backtraces_[i, j] = coordinate_to_offset( + prev_row, prev_col, num_cols + ) + + return self.get_result(refs, hyps) + + +class WERTransformer(object): + def __init__(self, hyp_str, ref_str, verbose=True): + self.ed_ = EditDistance(False) + self.id2oracle_errs_ = {} + self.utts_ = 0 + self.words_ = 0 + self.insertions_ = 0 + self.deletions_ = 0 + self.substitutions_ = 0 + + self.process(["dummy_str", hyp_str, ref_str]) + + if verbose: + print("'%s' vs '%s'" % (hyp_str, ref_str)) + self.report_result() + + def process(self, input): # std::vector&& input + if len(input) < 3: + print( + "Input must be of the form ... , got ", + len(input), + " inputs:", + ) + return None + + # Align + # std::vector hyps; + # std::vector refs; + + hyps = str2toks(input[-2]) + refs = str2toks(input[-1]) + + alignment = self.ed_.align(refs, hyps) + if alignment is None: + print("Alignment is null") + return np.nan + + # Tally errors + ins = 0 + dels = 0 + subs = 0 + for code in alignment.codes: + if code == Code.substitution: + subs += 1 + elif code == Code.insertion: + ins += 1 + elif code == Code.deletion: + dels += 1 + + # Output + row = input + row.append(str(len(refs))) + row.append(str(ins)) + row.append(str(dels)) + row.append(str(subs)) + # print(row) + + # Accumulate + kIdIndex = 0 + kNBestSep = "/" + + pieces = input[kIdIndex].split(kNBestSep) + + if len(pieces) == 0: + print( + "Error splitting ", + input[kIdIndex], + " on '", + kNBestSep, + "', got empty list", + ) + return np.nan + + id = pieces[0] + if id not in self.id2oracle_errs_: + self.utts_ += 1 + self.words_ += len(refs) + self.insertions_ += ins + self.deletions_ += dels + self.substitutions_ += subs + self.id2oracle_errs_[id] = [ins, dels, subs] + else: + curr_err = ins + dels + subs + prev_err = np.sum(self.id2oracle_errs_[id]) + if curr_err < prev_err: + self.id2oracle_errs_[id] = [ins, dels, subs] + + return 0 + + def report_result(self): + # print("---------- Summary ---------------") + if self.words_ == 0: + print("No words counted") + return + + # 1-best + best_wer = ( + 100.0 + * (self.insertions_ + self.deletions_ + self.substitutions_) + / self.words_ + ) + + print( + "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, " + "%0.2f%% dels, %0.2f%% subs)" + % ( + best_wer, + self.utts_, + self.words_, + 100.0 * self.insertions_ / self.words_, + 100.0 * self.deletions_ / self.words_, + 100.0 * self.substitutions_ / self.words_, + ) + ) + + def wer(self): + if self.words_ == 0: + wer = np.nan + else: + wer = ( + 100.0 + * (self.insertions_ + self.deletions_ + self.substitutions_) + / self.words_ + ) + return wer + + def stats(self): + if self.words_ == 0: + stats = {} + else: + wer = ( + 100.0 + * (self.insertions_ + self.deletions_ + self.substitutions_) + / self.words_ + ) + stats = dict( + { + "wer": wer, + "utts": self.utts_, + "numwords": self.words_, + "ins": self.insertions_, + "dels": self.deletions_, + "subs": self.substitutions_, + "confusion_pairs": self.ed_.confusion_pairs_, + } + ) + return stats + + +def calc_wer(hyp_str, ref_str): + t = WERTransformer(hyp_str, ref_str, verbose=0) + return t.wer() + + +def calc_wer_stats(hyp_str, ref_str): + t = WERTransformer(hyp_str, ref_str, verbose=0) + return t.stats() + + +def get_wer_alignment_codes(hyp_str, ref_str): + """ + INPUT: hypothesis string, reference string + OUTPUT: List of alignment codes (intermediate results from WER computation) + """ + t = WERTransformer(hyp_str, ref_str, verbose=0) + return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes + + +def merge_counts(x, y): + # Merge two hashes which have 'counts' as their values + # This can be used for example to merge confusion pair counts + # conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs']) + for k, v in y.items(): + if k not in x: + x[k] = 0 + x[k] += v + return x diff --git a/examples/speech_recognition/w2l_decoder.py b/examples/speech_recognition/w2l_decoder.py new file mode 100644 index 0000000..b89c2a9 --- /dev/null +++ b/examples/speech_recognition/w2l_decoder.py @@ -0,0 +1,453 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Wav2letter decoders. +""" + +import gc +import pdb +import itertools as it +import os.path as osp +import warnings +from collections import deque, namedtuple + +import numpy as np +import torch +from examples.speech_recognition.data.replabels import unpack_replabels +from fairseq import tasks, pdb +from fairseq.utils import apply_to_sample + + +try: + from wav2letter.common import create_word_dict, load_words + from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes + from wav2letter.decoder import ( + CriterionType, + DecoderOptions, + KenLM, + LM, + LMState, + SmearingMode, + Trie, + LexiconDecoder, + ) +except: + warnings.warn( + "wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings" + ) + LM = object + LMState = object + + +class W2lDecoder(object): + def __init__(self, args, tgt_dict): + self.tgt_dict = tgt_dict + self.vocab_size = len(tgt_dict) + self.nbest = args.nbest + + # criterion-specific init + if args.criterion == "ctc": + self.criterion_type = "ctc" + self.blank = ( + tgt_dict.index("") + if "" in tgt_dict.indices + else tgt_dict.bos() + ) + self.asg_transitions = None + elif args.criterion == "asg_loss": + self.criterion_type = CriterionType.ASG + self.blank = -1 + self.asg_transitions = args.asg_transitions + self.max_replabel = args.max_replabel + assert len(self.asg_transitions) == self.vocab_size ** 2 + else: + raise RuntimeError(f"unknown criterion: {args.criterion}") + + def generate(self, models, sample, **unused): + """Generate a batch of inferences.""" + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" + } + emissions = self.get_emissions(models, encoder_input) + return self.decode(emissions) + + def get_emissions(self, models, encoder_input): + """Run encoder and normalize emissions""" + # encoder_out = models[0].encoder(**encoder_input) + encoder_out = models[0](**encoder_input) + if self.criterion_type == "ctc": + emissions = models[0].get_normalized_probs(encoder_out, log_probs=True) + elif self.criterion_type == CriterionType.ASG: + emissions = encoder_out["encoder_out"] + return emissions.transpose(0, 1).float().cpu().contiguous() + + def get_tokens(self, idxs): + """Normalize tokens by handling CTC blank, ASG replabels, etc.""" + idxs = (g[0] for g in it.groupby(idxs)) + if self.criterion_type == "ctc": + idxs = filter(lambda x: x != self.blank, idxs) + elif self.criterion_type == CriterionType.ASG: + idxs = filter(lambda x: x >= 0, idxs) + idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel) + return torch.LongTensor(list(idxs)) + + +class W2lViterbiDecoder(W2lDecoder): + def __init__(self, args, tgt_dict): + super().__init__(args, tgt_dict) + + def decode(self, emissions): + B, T, N = emissions.size() + results = [] + """ + hypos = [] + if self.asg_transitions is None: + transitions = torch.FloatTensor(N, N).zero_() + else: + transitions = torch.FloatTensor(self.asg_transitions).view(N, N) + viterbi_path = torch.IntTensor(B, T) + workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N)) + CpuViterbiPath.compute( + B, + T, + N, + get_data_ptr_as_bytes(emissions), + get_data_ptr_as_bytes(transitions), + get_data_ptr_as_bytes(viterbi_path), + get_data_ptr_as_bytes(workspace), + ) + return [ + [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] + for b in range(B) + ] + """ + for b in range(B): + am_scores = emissions[b] + tokens = am_scores.argmax(dim=-1) + ids = [] + for i in range(T): + if tokens[i].item() != self.tgt_dict.bos(): + if i == 0: + ids.append(tokens[i].item()) + elif tokens[i].item() != tokens[i-1].item(): + ids.append(tokens[i].item()) + results.append([{"tokens": torch.LongTensor(ids), "score": 0.0}]) + return results + + + + +class W2lKenLMDecoder(W2lDecoder): + def __init__(self, args, tgt_dict): + super().__init__(args, tgt_dict) + + self.silence = ( + tgt_dict.index("") + if "" in tgt_dict.indices + else tgt_dict.bos() + ) + self.lexicon = load_words(args.lexicon) + self.word_dict = create_word_dict(self.lexicon) + self.unk_word = self.word_dict.get_index("") + + self.lm = KenLM(args.kenlm_model, self.word_dict) + self.trie = Trie(self.vocab_size, self.silence) + + start_state = self.lm.start(False) + for i, (word, spellings) in enumerate(self.lexicon.items()): + word_idx = self.word_dict.get_index(word) + _, score = self.lm.score(start_state, word_idx) + for spelling in spellings: + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder_opts = DecoderOptions( + args.beam, + int(getattr(args, "beam_size_token", len(tgt_dict))), + args.beam_threshold, + args.lm_weight, + args.word_score, + args.unk_weight, + args.sil_weight, + 0, + False, + self.criterion_type, + ) + + if self.asg_transitions is None: + N = 768 + # self.asg_transitions = torch.FloatTensor(N, N).zero_() + self.asg_transitions = [] + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + self.asg_transitions, + False, + ) + + def decode(self, emissions): + B, T, N = emissions.size() + hypos = [] + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] + hypos.append( + [ + { + "tokens": self.get_tokens(result.tokens), + "score": result.score, + "words": [ + self.word_dict.get_entry(x) for x in result.words if x >= 0 + ], + } + for result in nbest_results + ] + ) + return hypos + + +FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"]) + + +class FairseqLM(LM): + def __init__(self, dictionary, model): + LM.__init__(self) + self.dictionary = dictionary + self.model = model + self.unk = self.dictionary.unk() + + self.save_incremental = False # this currently does not work properly + self.max_cache = 20_000 + + model.cuda() + model.eval() + model.make_generation_fast_() + + self.states = {} + self.stateq = deque() + + def start(self, start_with_nothing): + state = LMState() + prefix = torch.LongTensor([[self.dictionary.eos()]]) + incremental_state = {} if self.save_incremental else None + with torch.no_grad(): + res = self.model(prefix.cuda(), incremental_state=incremental_state) + probs = self.model.get_normalized_probs(res, log_probs=True, sample=None) + + if incremental_state is not None: + incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state) + self.states[state] = FairseqLMState( + prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy() + ) + self.stateq.append(state) + + return state + + def score(self, state: LMState, token_index: int, no_cache: bool = False): + """ + Evaluate language model based on the current lm state and new word + Parameters: + ----------- + state: current lm state + token_index: index of the word + (can be lexicon index then you should store inside LM the + mapping between indices of lexicon and lm, or lm index of a word) + + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + curr_state = self.states[state] + + def trim_cache(targ_size): + while len(self.stateq) > targ_size: + rem_k = self.stateq.popleft() + rem_st = self.states[rem_k] + rem_st = FairseqLMState(rem_st.prefix, None, None) + self.states[rem_k] = rem_st + + if curr_state.probs is None: + new_incremental_state = ( + curr_state.incremental_state.copy() + if curr_state.incremental_state is not None + else None + ) + with torch.no_grad(): + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cuda(), new_incremental_state + ) + elif self.save_incremental: + new_incremental_state = {} + + res = self.model( + torch.from_numpy(curr_state.prefix).cuda(), + incremental_state=new_incremental_state, + ) + probs = self.model.get_normalized_probs( + res, log_probs=True, sample=None + ) + + if new_incremental_state is not None: + new_incremental_state = apply_to_sample( + lambda x: x.cpu(), new_incremental_state + ) + + curr_state = FairseqLMState( + curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy() + ) + + if not no_cache: + self.states[state] = curr_state + self.stateq.append(state) + + score = curr_state.probs[token_index].item() + + trim_cache(self.max_cache) + + outstate = state.child(token_index) + if outstate not in self.states and not no_cache: + prefix = np.concatenate( + [curr_state.prefix, torch.LongTensor([[token_index]])], -1 + ) + incr_state = curr_state.incremental_state + + self.states[outstate] = FairseqLMState(prefix, incr_state, None) + + if token_index == self.unk: + score = float("-inf") + + return outstate, score + + def finish(self, state: LMState): + """ + Evaluate eos for language model based on the current lm state + + Returns: + -------- + (LMState, float): pair of (new state, score for the current word) + """ + return self.score(state, self.dictionary.eos()) + + def empty_cache(self): + self.states = {} + self.stateq = deque() + gc.collect() + + +class W2lFairseqLMDecoder(W2lDecoder): + def __init__(self, args, tgt_dict): + super().__init__(args, tgt_dict) + + self.silence = tgt_dict.bos() + + self.unit_lm = getattr(args, "unit_lm", False) + + self.lexicon = load_words(args.lexicon) if args.lexicon else None + self.idx_to_wrd = {} + + checkpoint = torch.load(args.kenlm_model, map_location="cpu") + lm_args = checkpoint["args"] + lm_args.data = osp.dirname(args.kenlm_model) + print(lm_args) + task = tasks.setup_task(lm_args) + model = task.build_model(lm_args) + model.load_state_dict(checkpoint["model"], strict=False) + + self.trie = Trie(self.vocab_size, self.silence) + + self.word_dict = task.dictionary + self.unk_word = self.word_dict.unk() + self.lm = FairseqLM(self.word_dict, model) + + self.decoder_opts = DecoderOptions( + args.beam, + int(getattr(args, "beam_size_token", len(tgt_dict))), + args.beam_threshold, + args.lm_weight, + args.word_score, + args.unk_weight, + args.sil_weight, + 0, + False, + self.criterion_type, + ) + + if self.lexicon: + start_state = self.lm.start(False) + for i, (word, spellings) in enumerate(self.lexicon.items()): + if self.unit_lm: + word_idx = i + self.idx_to_wrd[i] = word + score = 0 + else: + word_idx = self.word_dict.index(word) + _, score = self.lm.score(start_state, word_idx, no_cache=True) + + for spelling in spellings: + spelling_idxs = [tgt_dict.index(token) for token in spelling] + assert ( + tgt_dict.unk() not in spelling_idxs + ), f"{spelling} {spelling_idxs}" + self.trie.insert(spelling_idxs, word_idx, score) + self.trie.smear(SmearingMode.MAX) + + self.decoder = LexiconDecoder( + self.decoder_opts, + self.trie, + self.lm, + self.silence, + self.blank, + self.unk_word, + [], + self.unit_lm, + ) + else: + from wav2letter.decoder import LexiconFreeDecoder + self.decoder = LexiconFreeDecoder( + self.decoder_opts, self.lm, self.silence, self.blank, [] + ) + + def decode(self, emissions): + B, T, N = emissions.size() + hypos = [] + + def idx_to_word(idx): + if self.unit_lm: + return self.idx_to_wrd[idx] + else: + return self.word_dict[idx] + + def make_hypo(result): + hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score} + if self.lexicon: + hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0] + return hypo + + for b in range(B): + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) + results = self.decoder.decode(emissions_ptr, T, N) + + nbest_results = results[: self.nbest] + hypos.append([make_hypo(result) for result in nbest_results]) + self.lm.empty_cache() + + return hypos diff --git a/examples/unispeech/adjust_sample_rate.py b/examples/unispeech/adjust_sample_rate.py new file mode 100644 index 0000000..07c319d --- /dev/null +++ b/examples/unispeech/adjust_sample_rate.py @@ -0,0 +1,60 @@ +import argparse +import os +import librosa +import soundfile as sf + +from pydub import AudioSegment + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--wav-path', type=str + ) + parser.add_argument( + '--dest-path', type=str + ) + parser.add_argument( + '--input', type=str + ) + parser.add_argument( + '--output', type=str + ) + + return parser + + +def main(args): + os.makedirs(args.dest_path, exist_ok=True) + + f = open(args.input) + data = f.readlines() + f.close() + + wf = open(args.output, 'w') + wf.write(args.dest_path+"\n") + count = len(data) + for line in data: + wav_name = line.strip() + wav_file = os.path.join(args.wav_path, wav_name) + base_wav_name = os.path.splitext(wav_name)[0] + output_file = os.path.join(args.dest_path, base_wav_name+".wav") + if os.path.exists(wav_file) and not os.path.exists(output_file): + sound = AudioSegment.from_mp3(wav_file) + sound.export(os.path.join(args.dest_path, 'tmp.wav'), format='wav') + y, sr = librosa.load(os.path.join(args.dest_path, 'tmp.wav'), sr=16000) + sf.write(output_file, y, sr) + infos = sf.info(output_file) + frames = infos.frames + sr = infos.samplerate + wf.write("{}\t{}\t{}\n".format(base_wav_name+".wav", frames, sr)) + count += 1 + if count % 100 == 0: + print('process {} done'.format(count)) + + wf.close() + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/examples/unispeech/libri_labels.py b/examples/unispeech/libri_labels.py new file mode 100644 index 0000000..3fa1ec4 --- /dev/null +++ b/examples/unispeech/libri_labels.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Helper script to pre-compute embeddings for a wav2letter++ dataset +""" + +import argparse +import os + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("tsv") + parser.add_argument("--output-dir", required=True) + parser.add_argument("--output-name", required=True) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + transcriptions = {} + + with open(args.tsv, "r") as tsv, open( + os.path.join(args.output_dir, args.output_name + ".ltr"), "w" + ) as ltr_out, open( + os.path.join(args.output_dir, args.output_name + ".wrd"), "w" + ) as wrd_out: + root = next(tsv).strip() + for line in tsv: + line = line.strip() + dir = os.path.dirname(line) + if dir not in transcriptions: + parts = dir.split(os.path.sep) + trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt" + path = os.path.join(root, dir, trans_path) + assert os.path.exists(path) + texts = {} + with open(path, "r") as trans_f: + for tline in trans_f: + items = tline.strip().split() + texts[items[0]] = " ".join(items[1:]) + transcriptions[dir] = texts + part = os.path.basename(line).split(".")[0] + assert part in transcriptions[dir] + print(transcriptions[dir][part], file=wrd_out) + print( + " ".join(list(transcriptions[dir][part].replace(" ", "|"))) + " |", + file=ltr_out, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/unispeech/scripts/continue_pretrain.sh b/examples/unispeech/scripts/continue_pretrain.sh new file mode 100644 index 0000000..9d3e6aa --- /dev/null +++ b/examples/unispeech/scripts/continue_pretrain.sh @@ -0,0 +1,11 @@ +model_path=YOUR_MODEL_PATH +train_subset=pretrain_HOUR_16k +valid_subset=valSeqs_1.0_uniform_new_version_16k +WORLD_SIZE=8 + + +update_freq=2 + +mkdir -p ${model_path} + +python train.py --distributed-world-size ${WORLD_SIZE} --distributed-port 0 examples/unispeech/data/LANG --save-dir ${model_path} --fp16 --num-workers 10 --task audio_pretraining --criterion wav2vec --arch unispeech --train-subset ${train_subset} --valid-subset ${valid_subset} --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --normalize --extractor-mode "layer_norm" --encoder-layers 24 --encoder-embed-dim 1024 --encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --final-dim 768 --layer-norm-first --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.1,0.999995)' --infonce --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 100000 --lr 0.0002 --warmup-updates 10000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 --loss-weights '[0.1, 0]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 250000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1000000 --max-update 100000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --update-freq ${update_freq} --pretrained-path PRETRAINED_MODEL_FROM_STAGE_1 --no-epoch-checkpoints --transpose --save-interval 4 --validate-interval 4 diff --git a/examples/unispeech/scripts/finetune.sh b/examples/unispeech/scripts/finetune.sh new file mode 100644 index 0000000..4bb0d6e --- /dev/null +++ b/examples/unispeech/scripts/finetune.sh @@ -0,0 +1,11 @@ +model_path=~/test +pretrained_model=/datablob/users/v-chengw/commonvoice_model/ptmtls2_en1350.large.lr1e-3.64gpu.fgm0.1/checkpoint_best.pt +train_subset=trainSeqs_1.0_uniform_new_version_16k +valid_subset=valSeqs_1.0_uniform_new_version_16k + +mkdir -p ${model_path} +WORLD_SIZE=4 +updata_freq=1 + + +python train.py --distributed-world-size $WORLD_SIZE --distributed-port 0 /datablob/users/v-chengw/data/commonvoice_20200622/resource/nl --save-dir ${model_path} --post-process word --train-subset ${train_subset} --valid-subset ${valid_subset} --no-epoch-checkpoints --best-checkpoint-metric uer --num-workers 4 --max-update 20000 --sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path ${pretrained_model} --labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.75 --layerdrop 0.1 --mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.25 --zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 2000 --validate-after-updates 2000 --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 2000 --hold-steps 8000 --decay-steps 10000 --final-lr-scale 0.05 --activation-dropout 0.1 --dropout 0.1 --attention-dropout 0.1 --final-dropout 0.1 --dropout-input 0.1 --criterion ctc --max-tokens 1000000 --seed 1337 --log-format json --log-interval 100 --ddp-backend no_c10d --fp16 --update-freq ${updata_freq} --dict-path /datablob/users/v-chengw/data/commonvoice_20200622/common_voices_splits/nl/phonesMatches_reduced.json --save-interval 10 --validate-interval 10 --normalize diff --git a/examples/unispeech/scripts/multilingal_large_pretrain.sh b/examples/unispeech/scripts/multilingal_large_pretrain.sh new file mode 100644 index 0000000..c8557c8 --- /dev/null +++ b/examples/unispeech/scripts/multilingal_large_pretrain.sh @@ -0,0 +1,13 @@ +model_path=MODEL_PATH +valid_subset=en/valid_16k +WORLD_SIZE=NUM_OF_GPUS + + +update_freq=$((64/$WORLD_SIZE)) #ngpu * update_freq = 64 + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + + +mkdir -p ${model_path} + +python -m torch.distributed.launch $DISTRIBUTED_ARGS train.py --distributed-world-size ${WORLD_SIZE} --distributed-port 0 examples/unispeech/data --save-dir ${model_path} --fp16 --num-workers 10 --task audio_pretraining --criterion wav2vec_mtl --arch unispeech --extractor-mode "layer_norm" --encoder-layers 24 --encoder-embed-dim 1024 --encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --final-dim 768 --layer-norm-first --conv-bias --logit-temp 0.1 --train-subset en/pretrain_1350_16k,es/pretrain_168_16k_sep,fr/pretrain_353_16k_sep,it/pretrain_90_16k_sep --valid-subset ${valid_subset} --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.1,0.999995)' --infonce --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 200000 --lr 0.001 --warmup-updates 25000 --mask-length 10 --mask-prob 0.5 --mask-selection static --mask-other 0 --encoder-layerdrop 0.0 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 1.0 --loss-weights '[0.1, 0]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 320000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1200000 --max-update 250000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --update-freq ${update_freq} --post-process word --labels ltr --dict-path examples/unispeech/data/en/vocab_sep.json --negatives-from-everywhere --mtlalpha 0.5 --replace-prob 0.5 --transpose --no-epoch-checkpoints --log-format json diff --git a/examples/unispeech/scripts/one2one_large_pretrain_en1350.sh b/examples/unispeech/scripts/one2one_large_pretrain_en1350.sh new file mode 100644 index 0000000..c505520 --- /dev/null +++ b/examples/unispeech/scripts/one2one_large_pretrain_en1350.sh @@ -0,0 +1,12 @@ +model_path=MODEL_PATH +train_subset=pretrain_1350_16k +valid_subset=valid_16k +WORLD_SIZE=NUM_OF_GPUS + + +update_freq=$((64/$WORLD_SIZE)) #ngpu * update_freq = 64 +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +mkdir -p ${model_path} + +python -m torch.distributed.launch $DISTRIBUTED_ARGS train.py --distributed-world-size ${WORLD_SIZE} --distributed-port 0 examples/unispeech/data/en --save-dir ${model_path} --fp16 --num-workers 10 --task audio_pretraining --criterion wav2vec_mtl --arch unispeech --extractor-mode "layer_norm" --encoder-layers 24 --encoder-embed-dim 1024 --encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --final-dim 768 --layer-norm-first --conv-bias --logit-temp 0.1 --train-subset ${train_subset} --valid-subset ${valid_subset} --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.1,0.999995)' --infonce --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 200000 --lr 0.001 --warmup-updates 25000 --mask-length 10 --mask-prob 0.5 --mask-selection static --mask-other 0 --encoder-layerdrop 0.0 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 1.0 --loss-weights '[0.1, 0]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 320000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1200000 --max-update 250000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --update-freq ${update_freq} --post-process word --labels ltr --dict-path examples/unispeech/data/en/vocab.json --negatives-from-everywhere --mtlalpha 0.5 --replace-prob 0.5 --transpose --no-epoch-checkpoints --log-format json diff --git a/examples/unispeech/unispeech_manifest.py b/examples/unispeech/unispeech_manifest.py new file mode 100644 index 0000000..3c85cf9 --- /dev/null +++ b/examples/unispeech/unispeech_manifest.py @@ -0,0 +1,46 @@ +import argparse +import os + + +def get_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + 'input', + type=str, + help="input .tsv file" + ) + parser.add_argument( + '--dest', + type=str, + help="output directory" + ) + + return parser + + +def main(args): + wav_names = [] + text = [] + with open(args.input) as f: + f.readline() + for line in f: + items = line.strip().split("\t") + wav_names.append(items[1]) + text.append(items[2]) + base_name = os.path.basename(args.input) + file_name = os.path.splitext(base_name)[0] + + with open(os.path.join(args.dest, file_name+'.list'), 'w') as f: + for name in wav_names: + f.write(name+"\n") + + with open(os.path.join(args.dest, file_name+'.text'), 'w') as f: + for i in range(len(wav_names)): + f.write("{}\t{}\n".format(wav_names[i], text[i])) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/examples/unispeech/wav2vec_manifest.py b/examples/unispeech/wav2vec_manifest.py new file mode 100644 index 0000000..1d27f58 --- /dev/null +++ b/examples/unispeech/wav2vec_manifest.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Data pre-processing: build vocabularies and binarize training data. +""" + +import argparse +import glob +import os +import random + +import soundfile + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "root", metavar="DIR", help="root directory containing flac files to index" + ) + parser.add_argument( + "--valid-percent", + default=0.01, + type=float, + metavar="D", + help="percentage of data to use as validation set (between 0 and 1)", + ) + parser.add_argument( + "--dest", default=".", type=str, metavar="DIR", help="output directory" + ) + parser.add_argument( + "--ext", default="flac", type=str, metavar="EXT", help="extension to look for" + ) + parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed") + parser.add_argument( + "--path-must-contain", + default=None, + type=str, + metavar="FRAG", + help="if set, path must contain this substring for a file to be included in the manifest", + ) + return parser + + +def main(args): + assert args.valid_percent >= 0 and args.valid_percent <= 1.0 + + dir_path = os.path.realpath(args.root) + search_path = os.path.join(dir_path, "**/*." + args.ext) + rand = random.Random(args.seed) + + with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open( + os.path.join(args.dest, "valid.tsv"), "w" + ) as valid_f: + print(dir_path, file=train_f) + print(dir_path, file=valid_f) + + for fname in glob.iglob(search_path, recursive=True): + file_path = os.path.realpath(fname) + + if args.path_must_contain and args.path_must_contain not in file_path: + continue + + frames = soundfile.info(fname).frames + dest = train_f if rand.random() > args.valid_percent else valid_f + print( + "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest + ) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args) diff --git a/fairseq/__init__.py b/fairseq/__init__.py new file mode 100644 index 0000000..ccd45ad --- /dev/null +++ b/fairseq/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import os +import sys + +try: + from .version import __version__ # noqa +except ImportError: + version_txt = os.path.join(os.path.dirname(__file__), "version.txt") + with open(version_txt) as f: + __version__ = f.read().strip() + +__all__ = ["pdb"] + +# backwards compatibility to support `from fairseq.meters import AverageMeter` +from fairseq.logging import meters, metrics, progress_bar # noqa + +sys.modules["fairseq.meters"] = meters +sys.modules["fairseq.metrics"] = metrics +sys.modules["fairseq.progress_bar"] = progress_bar + +# initialize hydra +from fairseq.dataclass.initialize import hydra_init +hydra_init() + +import fairseq.criterions # noqa +import fairseq.models # noqa +import fairseq.modules # noqa +import fairseq.optim # noqa +import fairseq.optim.lr_scheduler # noqa +import fairseq.pdb # noqa +import fairseq.scoring # noqa +import fairseq.tasks # noqa +import fairseq.token_generation_constraints # noqa + +import fairseq.benchmark # noqa +import fairseq.model_parallel # noqa diff --git a/fairseq/benchmark/__init__.py b/fairseq/benchmark/__init__.py new file mode 100644 index 0000000..f658466 --- /dev/null +++ b/fairseq/benchmark/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# import models/tasks to register them +from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py new file mode 100644 index 0000000..f3146b3 --- /dev/null +++ b/fairseq/benchmark/dummy_lm.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +from fairseq.data import Dictionary, FairseqDataset +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("dummy_lm") +class DummyLMTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("--dict-size", default=49996, type=int) + parser.add_argument("--dataset-size", default=100000, type=int) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments " + "per sample for BERT dataset", + ) + parser.add_argument("--add-bos-token", action="store_true", help="unused") + parser.add_argument( + "--max-target-positions", + default=None, + help="max number of tokens in the target sequence", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + + seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1 + + self.dummy_src = seq[:-1] + self.dummy_tgt = seq[1:] + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task. """ + dictionary = Dictionary() + for i in range(args.dict_size): + dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if self.args.batch_size is not None: + bsz = self.args.batch_size + else: + bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.args.tokens_per_sample, dtype=torch.long + ), + }, + "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), + "nsentences": bsz, + "ntokens": bsz * self.args.tokens_per_sample, + }, + num_items=self.args.dataset_size, + item_size=self.args.tokens_per_sample, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class DummyDataset(FairseqDataset): + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py new file mode 100644 index 0000000..ab506fe --- /dev/null +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +from fairseq.data import Dictionary, FairseqDataset +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("dummy_masked_lm") +class DummyMaskedLMTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("--dict-size", default=49995, type=int) + parser.add_argument("--dataset-size", default=100000, type=int) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments " + "per sample for BERT dataset", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + + # add mask token + self.mask_idx = dictionary.add_symbol("") + dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + + mask_idx = 0 + pad_idx = 1 + seq = torch.arange(args.tokens_per_sample) + pad_idx + 1 + mask = torch.arange(2, args.tokens_per_sample, 7) # ~15% + src = seq.clone() + src[mask] = mask_idx + tgt = torch.full_like(seq, pad_idx) + tgt[mask] = seq[mask] + + self.dummy_src = src + self.dummy_tgt = tgt + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task. """ + dictionary = Dictionary() + for i in range(args.dict_size): + dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if self.args.batch_size is not None: + bsz = self.args.batch_size + else: + bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.args.tokens_per_sample, dtype=torch.long + ), + }, + "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), + "nsentences": bsz, + "ntokens": bsz * self.args.tokens_per_sample, + }, + num_items=self.args.dataset_size, + item_size=self.args.tokens_per_sample, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class DummyDataset(FairseqDataset): + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False diff --git a/fairseq/benchmark/dummy_model.py b/fairseq/benchmark/dummy_model.py new file mode 100644 index 0000000..ff26e4f --- /dev/null +++ b/fairseq/benchmark/dummy_model.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch.nn.functional as F +from fairseq.data import Dictionary +from fairseq.models import ( + FairseqDecoder, + FairseqLanguageModel, + register_model, + register_model_architecture, +) + + +@register_model("dummy_model") +class DummyModel(FairseqLanguageModel): + def __init__(self, args, encoder): + super().__init__(encoder) + self.args = args + + @staticmethod + def add_args(parser): + parser.add_argument("--num-layers", type=int, default=24) + parser.add_argument("--embed-dim", type=int, default=1024) + + @classmethod + def build_model(cls, args, task): + encoder = DummyEncoder( + num_embed=len(task.target_dictionary), + embed_dim=args.embed_dim, + num_layers=args.num_layers, + ) + return cls(args, encoder) + + def forward(self, src_tokens, masked_tokens=None, **kwargs): + return self.decoder(src_tokens, masked_tokens=masked_tokens) + + +class DummyEncoder(FairseqDecoder): + def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24): + super().__init__(Dictionary()) + self.embed = nn.Embedding( + num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0 + ) + self.layers_a = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection + nn.Linear(3 * embed_dim, embed_dim), # skip self-attention + nn.Linear(embed_dim, embed_dim), # output projection + nn.Dropout(), + ) + for i in range(num_layers) + ] + ) + self.layers_b = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 4 * embed_dim), # FFN + nn.ReLU(), + nn.Linear(4 * embed_dim, embed_dim), # FFN + nn.Dropout(0.1), + ) + for i in range(num_layers) + ] + ) + self.out_proj = nn.Linear(embed_dim, num_embed) + + def forward(self, tokens, masked_tokens=None): + x = self.embed(tokens) + for layer_a, layer_b in zip(self.layers_a, self.layers_b): + x = x + layer_a(x) + x = x + layer_b(x) + x = self.out_proj(x) + if masked_tokens is not None: + x = x[masked_tokens] + return (x,) + + def max_positions(self): + return 1024 + + def get_normalized_probs(self, net_output, log_probs, sample=None): + logits = net_output[0].float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + + +@register_model_architecture("dummy_model", "dummy_model") +def base_architecture(args): + pass diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py new file mode 100644 index 0000000..4ca7be9 --- /dev/null +++ b/fairseq/benchmark/dummy_mt.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +from fairseq.data import Dictionary, FairseqDataset +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("dummy_mt") +class DummyMTTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("--dict-size", default=49996, type=int) + parser.add_argument("--dataset-size", default=100000, type=int) + parser.add_argument("--src-len", default=30, type=int) + parser.add_argument("--tgt-len", default=30, type=int) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + + self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1 + self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1 + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task. """ + dictionary = Dictionary() + for i in range(args.dict_size): + dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(dictionary))) + + args.max_source_positions = args.src_len + dictionary.pad() + 2 + args.max_target_positions = args.tgt_len + dictionary.pad() + 2 + + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + item_size = max(self.args.src_len, self.args.tgt_len) + if self.args.batch_size is not None: + bsz = self.args.batch_size + else: + bsz = max(1, self.args.max_tokens // item_size) + tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.args.src_len, dtype=torch.long + ), + "prev_output_tokens": tgt.clone(), + }, + "target": tgt, + "nsentences": bsz, + "ntokens": bsz * self.args.tgt_len, + }, + num_items=self.args.dataset_size, + item_size=item_size, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class DummyDataset(FairseqDataset): + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py new file mode 100644 index 0000000..0255c08 --- /dev/null +++ b/fairseq/binarizer.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from collections import Counter + +import torch +from fairseq.file_io import PathManager +from fairseq.tokenizer import tokenize_line + + +def safe_readline(f): + pos = f.tell() + while True: + try: + return f.readline() + except UnicodeDecodeError: + pos -= 1 + f.seek(pos) # search where this character begins + + +class Binarizer: + @staticmethod + def binarize( + filename, + dict, + consumer, + tokenize=tokenize_line, + append_eos=True, + reverse_order=False, + offset=0, + end=-1, + already_numberized=False, + ): + nseq, ntok = 0, 0 + replaced = Counter() + + def replaced_consumer(word, idx): + if idx == dict.unk_index and word != dict.unk_word: + replaced.update([word]) + + with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: + f.seek(offset) + # next(f) breaks f.tell(), hence readline() must be used + line = safe_readline(f) + while line: + if end > 0 and f.tell() > end: + break + if already_numberized: + id_strings = line.strip().split() + id_list = [int(id_string) for id_string in id_strings] + if reverse_order: + id_list.reverse() + if append_eos: + id_list.append(dict.eos()) + ids = torch.IntTensor(id_list) + else: + ids = dict.encode_line( + line=line, + line_tokenizer=tokenize, + add_if_not_exist=False, + consumer=replaced_consumer, + append_eos=append_eos, + reverse_order=reverse_order, + ) + nseq += 1 + ntok += len(ids) + consumer(ids) + line = f.readline() + return { + "nseq": nseq, + "nunk": sum(replaced.values()), + "ntok": ntok, + "replaced": replaced, + } + + @staticmethod + def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1): + nseq = 0 + + with open(PathManager.get_local_path(filename), "r") as f: + f.seek(offset) + line = safe_readline(f) + while line: + if end > 0 and f.tell() > end: + break + ids = alignment_parser(line) + nseq += 1 + consumer(ids) + line = f.readline() + return {"nseq": nseq} + + @staticmethod + def find_offsets(filename, num_chunks): + with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: + size = os.fstat(f.fileno()).st_size + chunk_size = size // num_chunks + offsets = [0 for _ in range(num_chunks + 1)] + for i in range(1, num_chunks): + f.seek(chunk_size * i) + safe_readline(f) + offsets[i] = f.tell() + return offsets diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py new file mode 100644 index 0000000..fdee84c --- /dev/null +++ b/fairseq/checkpoint_utils.py @@ -0,0 +1,610 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import collections +import logging +import os +import re +import traceback +from collections import OrderedDict +from typing import Optional, Union + +import torch +from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + overwrite_args_by_name, +) +from fairseq.file_io import PathManager +from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig, open_dict +from torch.serialization import default_restore_location + + +logger = logging.getLogger(__name__) + + +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): + from fairseq import meters + + # only one worker should attempt to create the required dir + if cfg.distributed_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) + + prev_best = getattr(save_checkpoint, "best", val_loss) + if val_loss is not None: + best_function = max if cfg.maximize_best_checkpoint_metric else min + save_checkpoint.best = best_function(val_loss, prev_best) + + if cfg.no_save: + return + + trainer.consolidate_optimizer() + + if not trainer.is_data_parallel_master: + return + + def is_better(a, b): + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b + + write_timer = meters.StopwatchMeter() + write_timer.start() + + epoch = epoch_itr.epoch + end_of_epoch = epoch_itr.end_of_epoch() + updates = trainer.get_num_updates() + + suffix = cfg.checkpoint_suffix or "" + checkpoint_conds = collections.OrderedDict() + checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( + end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 + ) + checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( + not end_of_epoch + and cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 + ) + checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( + not hasattr(save_checkpoint, "best") + or is_better(val_loss, save_checkpoint.best) + ) + if val_loss is not None and cfg.keep_best_checkpoints > 0: + checkpoint_conds[ + "checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss) + ] = not hasattr(save_checkpoint, "best") or is_better( + val_loss, save_checkpoint.best + ) + checkpoint_conds[ + "checkpoint_last{}.pt".format(suffix) + ] = not cfg.no_last_checkpoints + + extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} + if hasattr(save_checkpoint, "best"): + extra_state.update({"best": save_checkpoint.best}) + + checkpoints = [ + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + ] + if len(checkpoints) > 0: + trainer.save_checkpoint(checkpoints[0], extra_state) + for cp in checkpoints[1:]: + PathManager.copy(checkpoints[0], cp, overwrite=True) + + write_timer.stop() + logger.info( + "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( + checkpoints[0], epoch, updates, val_loss, write_timer.sum + ) + ) + + if not end_of_epoch and cfg.keep_interval_updates > 0: + # remove old checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" + ) + for old_chk in checkpoints[cfg.keep_interval_updates :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + if cfg.keep_last_epochs > 0: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt") + for old_chk in checkpoints[cfg.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + if cfg.keep_best_checkpoints > 0: + # only keep the best N checkpoints according to validation metric + checkpoints = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( + cfg.best_checkpoint_metric + ), + ) + if not cfg.maximize_best_checkpoint_metric: + checkpoints = checkpoints[::-1] + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + + +def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): + """ + Load a checkpoint and restore the training iterator. + + *passthrough_args* will be passed through to + ``trainer.get_train_iterator``. + """ + + reset_optimizer = cfg.reset_optimizer + reset_lr_scheduler = cfg.reset_lr_scheduler + optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides) + reset_meters = cfg.reset_meters + reset_dataloader = cfg.reset_dataloader + + if cfg.finetune_from_model is not None and ( + reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) + + suffix = cfg.checkpoint_suffix + if ( + cfg.restore_file == "checkpoint_last.pt" + ): # default value of restore_file is 'checkpoint_last.pt' + checkpoint_path = os.path.join( + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + ) + first_launch = not PathManager.exists(checkpoint_path) + if cfg.finetune_from_model is not None and first_launch: + # if there is no last checkpoint to restore, start the finetune from pretrained model + # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. + if PathManager.exists(cfg.finetune_from_model): + checkpoint_path = cfg.finetune_from_model + reset_optimizer = True + reset_lr_scheduler = True + reset_meters = True + reset_dataloader = True + logger.info( + f"loading pretrained model from {checkpoint_path}: " + "optimizer, lr scheduler, meters, dataloader will be reset" + ) + else: + raise ValueError( + f"--funetune-from-model {cfg.finetune_from_model} does not exist" + ) + elif cfg.model_parallel_size > 1: + checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") + else: + checkpoint_path = cfg.restore_file + + if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model: + raise ValueError( + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(cfg) + ) + + extra_state = trainer.load_checkpoint( + checkpoint_path, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + reset_meters=reset_meters, + ) + + if ( + extra_state is not None + and "best" in extra_state + and not reset_optimizer + and not reset_meters + ): + save_checkpoint.best = extra_state["best"] + + if extra_state is not None and not reset_dataloader: + # restore iterator from checkpoint + itr_state = extra_state["train_iterator"] + epoch_itr = trainer.get_train_iterator( + epoch=itr_state["epoch"], load_dataset=True, **passthrough_args + ) + epoch_itr.load_state_dict(itr_state) + else: + epoch_itr = trainer.get_train_iterator( + epoch=1, load_dataset=True, **passthrough_args + ) + + trainer.lr_step(epoch_itr.epoch) + + return extra_state, epoch_itr + + +def load_checkpoint_to_cpu(path, arg_overrides=None): + """Loads a checkpoint to CPU (with upgrading for backward compatibility).""" + with open(PathManager.get_local_path(path), "rb") as f: + state = torch.load( + f, map_location=lambda s, l: default_restore_location(s, "cpu") + ) + + if "args" in state and state["args"] is not None and arg_overrides is not None: + args = state["args"] + for arg_name, arg_val in arg_overrides.items(): + setattr(args, arg_name, arg_val) + + if "cfg" in state and state["cfg"] is not None and arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) + + state = _upgrade_state_dict(state) + return state + + +def load_model_ensemble( + filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1 +): + """Loads an ensemble of models. + + Args: + filenames (List[str]): checkpoint files to load + arg_overrides (Dict[str,Any], optional): override model args that + were used during model training + task (fairseq.tasks.FairseqTask, optional): task to use for loading + """ + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" + ensemble, args, _task = load_model_ensemble_and_task( + filenames, + arg_overrides, + task, + strict, + suffix, + num_shards, + ) + return ensemble, args + + +def load_model_ensemble_and_task( + filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1 +): + from fairseq import tasks + + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" + ensemble = [] + for filename in filenames: + orig_filename = filename + for shard_idx in range(num_shards): + if num_shards == 1: + filename = filename.replace(".pt", suffix + ".pt") + else: + filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + + if not PathManager.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + state = load_checkpoint_to_cpu(filename, arg_overrides) + if "args" in state and state["args"] is not None: + cfg = convert_namespace_to_omegaconf(state["args"]) + elif "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + else: + raise RuntimeError( + f"Neither args nor cfg exist in state keys = {state.keys()}" + ) + + if task is None: + task = tasks.setup_task(cfg.task) + + # build model for ensemble + model = task.build_model(cfg.model) + + model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) + ensemble.append(model) + return ensemble, cfg, task + + +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): + """Retrieves all checkpoints found in `path` directory. + + Checkpoints are identified by matching filename to the specified pattern. If + the pattern contains groups, the result will be sorted by the first group in + descending order. + """ + pt_regexp = re.compile(pattern) + files = os.listdir(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = float(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + + +def torch_persistent_save(obj, f): + if isinstance(f, str): + with PathManager.open(f, "wb") as h: + torch_persistent_save(obj, h) + return + for i in range(3): + try: + return torch.save(obj, f) + except Exception: + if i == 2: + logger.error(traceback.format_exc()) + + +def save_state( + filename, + cfg: FairseqConfig, + model_state_dict, + criterion, + optimizer, + lr_scheduler, + num_updates, + optim_history=None, + extra_state=None, + **kwargs, +): + from fairseq import utils + + if optim_history is None: + optim_history = [] + if extra_state is None: + extra_state = {} + state_dict = { + "cfg": cfg, + "args": kwargs.get("args", None), + "model": model_state_dict or {}, + "optimizer_history": optim_history + + [ + { + "criterion_name": criterion.__class__.__name__, + "optimizer_name": optimizer.__class__.__name__, + "lr_scheduler_state": lr_scheduler.state_dict(), + "num_updates": num_updates, + } + ], + "extra_state": extra_state, + } + if utils.has_parameters(criterion): + state_dict["criterion"] = criterion.state_dict() + + if cfg is None: + cfg = state_dict["args"] + assert cfg is not None, "must provide cfg or args" + + if isinstance(cfg, DictConfig): + no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state + else: + no_save_optimizer_state = cfg.no_save_optimizer_state + if not no_save_optimizer_state: + state_dict["last_optimizer_state"] = optimizer.state_dict() + + with PathManager.open(filename, "wb") as f: + torch_persistent_save(state_dict, f) + + +def _upgrade_state_dict(state): + """Helper for upgrading old model checkpoints.""" + from fairseq import models, registry, tasks + + # add optimizer_history + if "optimizer_history" not in state: + state["optimizer_history"] = [ + {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]} + ] + state["last_optimizer_state"] = state["optimizer"] + del state["optimizer"] + del state["best_loss"] + # move extra_state into sub-dictionary + if "epoch" in state and "extra_state" not in state: + state["extra_state"] = { + "epoch": state["epoch"], + "batch_offset": state["batch_offset"], + "val_loss": state["val_loss"], + } + del state["epoch"] + del state["batch_offset"] + del state["val_loss"] + # reduce optimizer history's memory usage (only keep the last state) + if "optimizer" in state["optimizer_history"][-1]: + state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"] + for optim_hist in state["optimizer_history"]: + del optim_hist["optimizer"] + # record the optimizer class name + if "optimizer_name" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG" + # move best_loss into lr_scheduler_state + if "lr_scheduler_state" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["lr_scheduler_state"] = { + "best": state["optimizer_history"][-1]["best_loss"] + } + del state["optimizer_history"][-1]["best_loss"] + # keep track of number of updates + if "num_updates" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["num_updates"] = 0 + # use stateful training data iterator + if "train_iterator" not in state["extra_state"]: + state["extra_state"]["train_iterator"] = { + "epoch": state["extra_state"]["epoch"], + "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), + } + + # old model checkpoints may not have separate source/target positions + # backward compatibility, cfg updates + if "args" in state and state["args"] is not None: + # default to translation task + if not hasattr(state["args"], "task"): + state["args"].task = "translation" + # --raw-text and --lazy-load are deprecated + if getattr(state["args"], "raw_text", False): + state["args"].dataset_impl = "raw" + elif getattr(state["args"], "lazy_load", False): + state["args"].dataset_impl = "lazy" + # epochs start at 1 + if state["extra_state"]["train_iterator"] is not None: + state["extra_state"]["train_iterator"]["epoch"] = max( + state["extra_state"]["train_iterator"].get("epoch", 1), 1 + ) + + if hasattr(state["args"], "remove_bpe"): + state["args"].post_process = state["args"].remove_bpe + + state["cfg"] = convert_namespace_to_omegaconf(state["args"]) + + if "cfg" in state and state["cfg"] is not None: + with open_dict(state["cfg"]): + if state["cfg"].task is not None: + if hasattr(state["cfg"].task, "max_positions") and not hasattr( + state["cfg"].task, "max_source_positions" + ): + state["cfg"].task.max_source_positions = state[ + "cfg" + ].task.max_positions + state["cfg"].task.max_target_positions = state[ + "cfg" + ].task.max_positions + + return state + + +def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): + """Prune the given state_dict if desired for LayerDrop + (https://arxiv.org/abs/1909.11556). + + Training with LayerDrop allows models to be robust to pruning at inference + time. This function prunes state_dict to allow smaller models to be loaded + from a larger model and re-maps the existing state_dict for this to occur. + + It's called by functions that load models from checkpoints and does not + need to be called directly. + """ + arch = None + if model_cfg is not None: + arch = ( + model_cfg._name + if isinstance(model_cfg, DictConfig) + else getattr(model_cfg, "arch", None) + ) + + if not model_cfg or arch is None or arch == "ptt_transformer": + # args should not be none, but don't crash if it is. + return state_dict + + encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None) + decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None) + + if not encoder_layers_to_keep and not decoder_layers_to_keep: + return state_dict + + # apply pruning + logger.info( + "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop" + ) + + def create_pruning_pass(layers_to_keep, layer_name): + keep_layers = sorted( + int(layer_string) for layer_string in layers_to_keep.split(",") + ) + mapping_dict = {} + for i in range(len(keep_layers)): + mapping_dict[str(keep_layers[i])] = str(i) + + regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)) + return {"substitution_regex": regex, "mapping_dict": mapping_dict} + + pruning_passes = [] + if encoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder")) + if decoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder")) + + new_state_dict = {} + for layer_name in state_dict.keys(): + match = re.search(r"\.layers\.(\d+)\.", layer_name) + # if layer has no number in it, it is a supporting layer, such as an + # embedding + if not match: + new_state_dict[layer_name] = state_dict[layer_name] + continue + + # otherwise, layer should be pruned. + original_layer_number = match.group(1) + # figure out which mapping dict to replace from + for pruning_pass in pruning_passes: + if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[ + "substitution_regex" + ].search(layer_name): + new_layer_number = pruning_pass["mapping_dict"][original_layer_number] + substitution_match = pruning_pass["substitution_regex"].search( + layer_name + ) + new_state_key = ( + layer_name[: substitution_match.start(1)] + + new_layer_number + + layer_name[substitution_match.end(1) :] + ) + new_state_dict[new_state_key] = state_dict[layer_name] + + # Since layers are now pruned, *_layers_to_keep are no longer needed. + # This is more of "It would make it work fix" rather than a proper fix. + + with open_dict(model_cfg): + if hasattr(model_cfg, "encoder_layers_to_keep"): + model_cfg.encoder_layers_to_keep = None + if hasattr(model_cfg, "decoder_layers_to_keep"): + model_cfg.decoder_layers_to_keep = None + + return new_state_dict + + +def load_pretrained_component_from_model( + component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str +): + """ + Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the + provided `component` object. If state_dict fails to load, there may be a + mismatch in the architecture of the corresponding `component` found in the + `checkpoint` file. + """ + if not PathManager.exists(checkpoint): + raise IOError("Model file not found: {}".format(checkpoint)) + state = load_checkpoint_to_cpu(checkpoint) + if isinstance(component, FairseqEncoder): + component_type = "encoder" + elif isinstance(component, FairseqDecoder): + component_type = "decoder" + else: + raise ValueError( + "component to load must be either a FairseqEncoder or " + "FairseqDecoder. Loading other component types are not supported." + ) + component_state_dict = OrderedDict() + for key in state["model"].keys(): + if key.startswith(component_type): + # encoder.input_layers.0.0.weight --> input_layers.0.0.weight + component_subkey = key[len(component_type) + 1 :] + component_state_dict[component_subkey] = state["model"][key] + component.load_state_dict(component_state_dict, strict=True) + return component + + +def verify_checkpoint_directory(save_dir: str) -> None: + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + temp_file_path = os.path.join(save_dir, "dummy") + try: + with open(temp_file_path, "w"): + pass + except OSError as e: + logger.warning( + "Unable to access checkpoint save directory: {}".format(save_dir) + ) + raise e + else: + os.remove(temp_file_path) diff --git a/fairseq/clib/libbleu/libbleu.cpp b/fairseq/clib/libbleu/libbleu.cpp new file mode 100644 index 0000000..3cf2d65 --- /dev/null +++ b/fairseq/clib/libbleu/libbleu.cpp @@ -0,0 +1,141 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +typedef struct +{ + size_t reflen; + size_t predlen; + size_t match1; + size_t count1; + size_t match2; + size_t count2; + size_t match3; + size_t count3; + size_t match4; + size_t count4; +} bleu_stat; + +// left trim (remove pad) +void bleu_ltrim(size_t* len, int** sent, int pad) { + size_t start = 0; + while(start < *len) { + if (*(*sent + start) != pad) { break; } + start++; + } + *sent += start; + *len -= start; +} + +// right trim remove (eos) +void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { + size_t end = *len - 1; + while (end > 0) { + if (*(*sent + end) != eos && *(*sent + end) != pad) { break; } + end--; + } + *len = end + 1; +} + +// left and right trim +void bleu_trim(size_t* len, int** sent, int pad, int eos) { + bleu_ltrim(len, sent, pad); + bleu_rtrim(len, sent, pad, eos); +} + +size_t bleu_hash(int len, int* data) { + size_t h = 14695981039346656037ul; + size_t prime = 0x100000001b3; + char* b = (char*) data; + size_t blen = sizeof(int) * len; + + while (blen-- > 0) { + h ^= *b++; + h *= prime; + } + + return h; +} + +void bleu_addngram( + size_t *ntotal, size_t *nmatch, size_t n, + size_t reflen, int* ref, size_t predlen, int* pred) { + + if (predlen < n) { return; } + + predlen = predlen - n + 1; + (*ntotal) += predlen; + + if (reflen < n) { return; } + + reflen = reflen - n + 1; + + std::map count; + while (predlen > 0) { + size_t w = bleu_hash(n, pred++); + count[w]++; + predlen--; + } + + while (reflen > 0) { + size_t w = bleu_hash(n, ref++); + if (count[w] > 0) { + (*nmatch)++; + count[w] -=1; + } + reflen--; + } +} + +extern "C" { + +#ifdef _WIN64 +__declspec(dllexport) +#endif +void bleu_zero_init(bleu_stat* stat) { + std::memset(stat, 0, sizeof(bleu_stat)); +} + +#ifdef _WIN64 +__declspec(dllexport) +#endif +void bleu_one_init(bleu_stat* stat) { + bleu_zero_init(stat); + stat->count1 = 0; + stat->count2 = 1; + stat->count3 = 1; + stat->count4 = 1; + stat->match1 = 0; + stat->match2 = 1; + stat->match3 = 1; + stat->match4 = 1; +} + +#ifdef _WIN64 +__declspec(dllexport) +#endif +void bleu_add( + bleu_stat* stat, + size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) { + + bleu_trim(&reflen, &ref, pad, eos); + bleu_trim(&predlen, &pred, pad, eos); + stat->reflen += reflen; + stat->predlen += predlen; + + bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); + bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); + bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); + bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); +} + +} diff --git a/fairseq/clib/libbleu/module.cpp b/fairseq/clib/libbleu/module.cpp new file mode 100644 index 0000000..8ed9a84 --- /dev/null +++ b/fairseq/clib/libbleu/module.cpp @@ -0,0 +1,37 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + + +static PyMethodDef method_def[] = { + {NULL, NULL, 0, NULL} +}; + +static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "libbleu", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + method_def +}; + + +#if PY_MAJOR_VERSION == 2 +PyMODINIT_FUNC init_libbleu() +#else +PyMODINIT_FUNC PyInit_libbleu() +#endif +{ + PyObject *m = PyModule_Create(&module_def); + if (!m) { + return NULL; + } + return m; +} diff --git a/fairseq/clib/libnat/edit_dist.cpp b/fairseq/clib/libnat/edit_dist.cpp new file mode 100644 index 0000000..6bc6a93 --- /dev/null +++ b/fairseq/clib/libnat/edit_dist.cpp @@ -0,0 +1,231 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // @manual=//caffe2:torch_extension +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ::std; + +vector> edit_distance2_with_dp( + vector& x, + vector& y) { + uint32_t lx = x.size(); + uint32_t ly = y.size(); + vector> d(lx + 1, vector(ly + 1)); + for (uint32_t i = 0; i < lx + 1; i++) { + d[i][0] = i; + } + for (uint32_t j = 0; j < ly + 1; j++) { + d[0][j] = j; + } + for (uint32_t i = 1; i < lx + 1; i++) { + for (uint32_t j = 1; j < ly + 1; j++) { + d[i][j] = + min(min(d[i - 1][j], d[i][j - 1]) + 1, + d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1)); + } + } + return d; +} + +vector> edit_distance2_backtracking( + vector>& d, + vector& x, + vector& y, + uint32_t terminal_symbol) { + vector seq; + vector> edit_seqs(x.size() + 2, vector()); + /* + edit_seqs: + 0~x.size() cell is the insertion sequences + last cell is the delete sequence + */ + + if (x.size() == 0) { + edit_seqs.at(0) = y; + return edit_seqs; + } + + uint32_t i = d.size() - 1; + uint32_t j = d.at(0).size() - 1; + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) { + seq.push_back(1); // insert + seq.push_back(y.at(j - 1)); + j--; + } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) { + seq.push_back(2); // delete + seq.push_back(x.at(i - 1)); + i--; + } else { + seq.push_back(3); // keep + seq.push_back(x.at(i - 1)); + i--; + j--; + } + } + + uint32_t prev_op, op, s, word; + prev_op = 0, s = 0; + for (uint32_t k = 0; k < seq.size() / 2; k++) { + op = seq.at(seq.size() - 2 * k - 2); + word = seq.at(seq.size() - 2 * k - 1); + if (prev_op != 1) { + s++; + } + if (op == 1) // insert + { + edit_seqs.at(s - 1).push_back(word); + } else if (op == 2) // delete + { + edit_seqs.at(x.size() + 1).push_back(1); + } else { + edit_seqs.at(x.size() + 1).push_back(0); + } + + prev_op = op; + } + + for (uint32_t k = 0; k < edit_seqs.size(); k++) { + if (edit_seqs[k].size() == 0) { + edit_seqs[k].push_back(terminal_symbol); + } + } + return edit_seqs; +} + +vector> edit_distance2_backtracking_with_delete( + vector>& d, + vector& x, + vector& y, + uint32_t terminal_symbol, + uint32_t deletion_symbol) { + vector seq; + vector> edit_seqs(x.size() + 1, vector()); + /* + edit_seqs: + 0~x.size() cell is the insertion sequences + last cell is the delete sequence + */ + + if (x.size() == 0) { + edit_seqs.at(0) = y; + return edit_seqs; + } + + uint32_t i = d.size() - 1; + uint32_t j = d.at(0).size() - 1; + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) { + seq.push_back(1); // insert + seq.push_back(y.at(j - 1)); + j--; + } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) { + seq.push_back(2); // delete + seq.push_back(x.at(i - 1)); + i--; + } else { + seq.push_back(3); // keep + seq.push_back(x.at(i - 1)); + i--; + j--; + } + } + + uint32_t prev_op, op, s, word; + prev_op = 0, s = 0; + for (uint32_t k = 0; k < seq.size() / 2; k++) { + op = seq.at(seq.size() - 2 * k - 2); + word = seq.at(seq.size() - 2 * k - 1); + if (prev_op != 1) { + s++; + } + if (op == 1) // insert + { + edit_seqs.at(s - 1).push_back(word); + } else if (op == 2) // delete + { + edit_seqs.at(s - 1).push_back(deletion_symbol); + } + + prev_op = op; + } + + for (uint32_t k = 0; k < edit_seqs.size(); k++) { + if (edit_seqs.at(k).size() == 0) { + edit_seqs.at(k).push_back(terminal_symbol); + } + } + return edit_seqs; +} + +vector compute_ed2( + vector>& xs, + vector>& ys) { + vector distances(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size()); + } + return distances; +} + +vector>> suggested_ed2_path( + vector>& xs, + vector>& ys, + uint32_t terminal_symbol) { + vector>> seq(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + seq.at(i) = + edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol); + } + return seq; +} + +vector>> suggested_ed2_path_with_delete( + vector>& xs, + vector>& ys, + uint32_t terminal_symbol, + uint32_t deletion_symbol) { + vector>> seq(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + seq.at(i) = edit_distance2_backtracking_with_delete( + d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol); + } + return seq; +} + +PYBIND11_MODULE(libnat, m) { + m.def("compute_ed2", &compute_ed2, "compute_ed2"); + m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path"); + m.def( + "suggested_ed2_path_with_delete", + &suggested_ed2_path_with_delete, + "suggested_ed2_path_with_delete"); +} diff --git a/fairseq/clib/libnat_cuda/binding.cpp b/fairseq/clib/libnat_cuda/binding.cpp new file mode 100644 index 0000000..aaa6244 --- /dev/null +++ b/fairseq/clib/libnat_cuda/binding.cpp @@ -0,0 +1,60 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + This code is partially adpoted from https://github.com/1ytic/pytorch-edit-distance + */ + +#include "edit_dist.h" +#include + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + + +torch::Tensor LevenshteinDistance( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length) { + + CHECK_INPUT(source); + CHECK_INPUT(target); + CHECK_INPUT(source_length); + CHECK_INPUT(target_length); + return LevenshteinDistanceCuda(source, target, source_length, target_length); +} + +torch::Tensor GenerateDeletionLabel( + torch::Tensor source, + torch::Tensor operations) { + + CHECK_INPUT(source); + CHECK_INPUT(operations); + return GenerateDeletionLabelCuda(source, operations); +} + +std::pair GenerateInsertionLabel( + torch::Tensor target, + torch::Tensor operations) { + + CHECK_INPUT(target); + CHECK_INPUT(operations); + return GenerateInsertionLabelCuda(target, operations); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); + m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label"); + m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label"); +} diff --git a/fairseq/clib/libnat_cuda/edit_dist.cu b/fairseq/clib/libnat_cuda/edit_dist.cu new file mode 100644 index 0000000..22de16b --- /dev/null +++ b/fairseq/clib/libnat_cuda/edit_dist.cu @@ -0,0 +1,332 @@ +/** +* Copyright 2017-present, Facebook, Inc. +* All rights reserved. +* +* This source code is licensed under the license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "edit_dist.h" +#include +#include +#include +#include +#include // std::pair + +template +__global__ void generate_deletion_label_kernel( + const scalar_t* __restrict__ source, + const size_t source_size, + const size_t operation_size, + int* __restrict__ operations, + int* __restrict__ labels) { + + const int index = blockIdx.x; + const int offset = index * operation_size; + const int offset_label = index * source_size; + + for (int i = 0; i < source_size; i++) { + labels[offset_label + i] = 0; + } + + int k = 0; + for (int i = 0; i < operation_size; i++){ + if (operations[offset + i] == 0){ + break; + } else if (operations[offset + i] == 1){ + continue; + } else { + labels[offset_label + k] = 3 - operations[offset + i]; + k++; + } + } +} + +template +__global__ void generate_insertion_label_kernel( + const scalar_t* __restrict__ target, + const size_t target_size, + const size_t operation_size, + int* __restrict__ operations, + int* __restrict__ labels, + int* __restrict__ masks) { + + const int index = blockIdx.x; + const int offset = index * operation_size; + const int offset_label = index * target_size; + + int k = 0; + int u = 0; + int m = 0; + + for (int i = 0; i < target_size; i++) { + labels[offset_label + i] = 0; + masks[offset_label + i] = 0; + } + + for (int i = 0; i < operation_size-1; i++){ + if (operations[offset + i] == 0){ + break; + } else if (operations[offset + i] == 2){ + continue; + } else if (operations[offset + i] == 1){ + masks[offset_label + m] = 1; + u++; m++; + } else { + labels[offset_label + k] = u; + masks[offset_label + m] = 0; + k++; m++; + u = 0; + } + } +} + +template +__global__ void levenshtein_distance_kernel( + const scalar_t* __restrict__ source, + const scalar_t* __restrict__ target, + const int* __restrict__ source_length, + const int* __restrict__ target_length, + const size_t source_size, + const size_t target_size, + int* __restrict__ operations, + int* __restrict__ errors_curr) { + + const int index = blockIdx.x; + const int offset = index * (source_size + target_size); + const int d = index * (source_size + 1) * (target_size + 1); + const int t = target_size + 1; + + auto err_idx = [d, t](int i, int j) { return d + i * t + j; }; + auto opt_idx = [offset](int k) { return offset + k; }; + + const int hyp_len = source_length[index]; + const int ref_len = target_length[index]; + const scalar_t* hyp_begin = source + index * source_size; + const scalar_t* ref_begin = target + index * target_size; + + // dynamic programming + for (int i = 0; i <= hyp_len; i++){ + errors_curr[err_idx(i, 0)] = i; + } + for (int j = 0; j <= ref_len; j++){ + errors_curr[err_idx(0, j)] = j; + } + for (int i = 1; i <= hyp_len; i++){ + for (int j = 1; j <= ref_len; j++){ + errors_curr[err_idx(i, j)] = min( + min( + errors_curr[err_idx(i-1, j)], + errors_curr[err_idx(i, j-1)] + ) + 1, + errors_curr[err_idx(i-1, j-1)] + 2 * ( + *(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1 + ) + ); + } + } + + // back-tracing + int i = hyp_len; + int j = ref_len; + int o = hyp_len + ref_len; + + for (int k = 0; k < source_size + target_size; k++) { + operations[opt_idx(k)] = 0; + } + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) { + o--; operations[opt_idx(o)] = 1; j--; // insertion + } else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) { + o--; operations[opt_idx(o)] = 2; i--; // deletion + } else { + o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing + } + } + + // moving to the left + for (int k = 0; k < hyp_len + ref_len; k++) { + if (k + o < hyp_len + ref_len){ + operations[opt_idx(k)] = operations[opt_idx(k+o)]; + } else{ + operations[opt_idx(k)] = 0; // padding + } + } + +} + +template +__global__ void faster_levenshtein_distance_kernel( + const scalar_t* __restrict__ source, + const scalar_t* __restrict__ target, + const int* __restrict__ source_length, + const int* __restrict__ target_length, + const size_t source_size, + const size_t target_size, + int* __restrict__ operations) { + + extern __shared__ short errors[]; + auto errors_curr = errors; + + const int index = blockIdx.x; + const int offset = index * (source_size + target_size); + const int t = target_size + 1; + + auto err_idx = [t](int i, int j) { return i * t + j; }; + auto opt_idx = [offset](int k) { return offset + k; }; + + const int hyp_len = source_length[index]; + const int ref_len = target_length[index]; + const scalar_t* hyp_begin = source + index * source_size; + const scalar_t* ref_begin = target + index * target_size; + + // dynamic programming + for (int i = 0; i <= hyp_len; i++){ + errors_curr[err_idx(i, 0)] = i; + } + for (int j = 0; j <= ref_len; j++){ + errors_curr[err_idx(0, j)] = j; + } + for (int i = 1; i <= hyp_len; i++){ + for (int j = 1; j <= ref_len; j++){ + errors_curr[err_idx(i, j)] = min( + min( + errors_curr[err_idx(i-1, j)], + errors_curr[err_idx(i, j-1)] + ) + 1, + errors_curr[err_idx(i-1, j-1)] + 2 * ( + *(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1 + ) + ); + } + } + + // back-tracing + int i = hyp_len; + int j = ref_len; + int o = hyp_len + ref_len; + + for (int k = 0; k < source_size + target_size; k++) { + operations[opt_idx(k)] = 0; + } + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) { + o--; operations[opt_idx(o)] = 1; j--; // insertion + } else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) { + o--; operations[opt_idx(o)] = 2; i--; // deletion + } else { + o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing + } + } + + // moving to the left + for (int k = 0; k < hyp_len + ref_len; k++) { + if (k + o < hyp_len + ref_len){ + operations[opt_idx(k)] = operations[opt_idx(k+o)]; + } else{ + operations[opt_idx(k)] = 0; // padding + } + } + +} + + +torch::Tensor GenerateDeletionLabelCuda( + torch::Tensor source, + torch::Tensor operations) { + + const auto batch_size = source.size(0); + at::TensorOptions options(source.device()); + options = options.dtype(at::ScalarType::Int); + auto labels = torch::empty({batch_size, source.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); + + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] { + generate_deletion_label_kernel<<>>( + source.data_ptr(), + source.size(1), + operations.size(1), + operations.data_ptr(), + labels.data_ptr()); + })); + + return labels; +} + +std::pair GenerateInsertionLabelCuda( + torch::Tensor target, + torch::Tensor operations) { + +const auto batch_size = target.size(0); +at::TensorOptions options(target.device()); +options = options.dtype(at::ScalarType::Int); +auto labels = torch::empty({batch_size, target.size(1)}, options); +auto masks = torch::empty({batch_size, target.size(1)}, options); +auto stream = at::cuda::getCurrentCUDAStream(target.device().index()); + +AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] { + generate_insertion_label_kernel<<>>( + target.data_ptr(), + target.size(1), + operations.size(1), + operations.data_ptr(), + labels.data_ptr(), + masks.data_ptr()); +})); + +return std::make_pair(labels, masks); +} + + +torch::Tensor LevenshteinDistanceCuda( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length) { + + const auto batch_size = source.size(0); + const auto shared_size = (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short); + + at::TensorOptions options(source.device()); + options = options.dtype(at::ScalarType::Int); + auto operations = torch::empty({batch_size, source.size(1) + target.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); + + if (shared_size > 40000) { + auto distances = torch::empty({batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options); + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] { + levenshtein_distance_kernel<<>>( + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), + source.size(1), + target.size(1), + operations.data_ptr(), + distances.data_ptr()); + })); + } else { + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] { + faster_levenshtein_distance_kernel<<>>( + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), + source.size(1), + target.size(1), + operations.data_ptr()); + })); + } + + return operations; +} diff --git a/fairseq/clib/libnat_cuda/edit_dist.h b/fairseq/clib/libnat_cuda/edit_dist.h new file mode 100644 index 0000000..e3506cd --- /dev/null +++ b/fairseq/clib/libnat_cuda/edit_dist.h @@ -0,0 +1,25 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +torch::Tensor LevenshteinDistanceCuda( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length); + +torch::Tensor GenerateDeletionLabelCuda( + torch::Tensor source, + torch::Tensor operations); + +std::pair GenerateInsertionLabelCuda( + torch::Tensor source, + torch::Tensor operations); diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py new file mode 100644 index 0000000..8cc6c0f --- /dev/null +++ b/fairseq/criterions/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from fairseq import registry +from fairseq.criterions.fairseq_criterion import ( # noqa + FairseqCriterion, + LegacyFairseqCriterion, +) +from omegaconf import DictConfig + + +( + build_criterion_, + register_criterion, + CRITERION_REGISTRY, + CRITERION_DATACLASS_REGISTRY, +) = registry.setup_registry( + "--criterion", base_class=FairseqCriterion, default="cross_entropy" +) + + +def build_criterion(cfg: DictConfig, task): + return build_criterion_(cfg, task) + + +# automatically import any Python files in the criterions/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.criterions." + file_name) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py new file mode 100644 index 0000000..b369ef2 --- /dev/null +++ b/fairseq/criterions/ctc.py @@ -0,0 +1,263 @@ +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math +from argparse import Namespace + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import LegacyFairseqCriterion, register_criterion +from fairseq.data.data_utils import post_process +from fairseq.logging.meters import safe_round + + +@register_criterion("ctc") +class CtcCriterion(LegacyFairseqCriterion): + def __init__(self, args, task): + super().__init__(args, task) + self.blank_idx = task.target_dictionary.bos() + self.pad_idx = task.target_dictionary.pad() + self.eos_idx = task.target_dictionary.eos() + self.post_process = args.post_process + + if args.wer_args is not None: + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args) + + dec_args = Namespace() + dec_args.nbest = 1 + dec_args.criterion = "ctc" + dec_args.kenlm_model = wer_compute_kenlm + dec_args.lexicon = wer_lexicon + dec_args.beam = 50 + dec_args.beam_size_token = min(50, len(task.target_dictionary)) + dec_args.beam_threshold = min(50, len(task.target_dictionary)) + dec_args.lm_weight = lm_w + dec_args.word_score = ws_w + dec_args.unk_weight = -math.inf + dec_args.sil_weight = 0 + + self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary) + else: + self.w2l_decoder = None + + self.zero_infinity = args.zero_infinity + self.sentence_avg = args.sentence_avg + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + parser.add_argument( + "--zero-infinity", action="store_true", help="zero inf loss" + ) + try: + parser.add_argument( + "--post-process", + "--remove-bpe", + default="letter", + help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)", + ) + except: + pass # this option might have been added from eval args + parser.add_argument( + "--wer-args", + type=str, + default=None, + help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \ + path to lexicon, lm score, word score", + ) + + def get_net_output(self, model, sample): + net_output = model(**sample["net_input"]) + return net_output + + def get_loss(self, model, sample, net_output, reduce=True): + lprobs = model.get_normalized_probs( + net_output, log_probs=True + ).contiguous() # (T, B, C) from the encoder + + if "src_lengths" in sample["net_input"]: + input_lengths = sample["net_input"]["src_lengths"] + else: + non_padding_mask = ~net_output["padding_mask"] + input_lengths = non_padding_mask.long().sum(-1) + + pad_mask = (sample["target"] != self.pad_idx) & ( + sample["target"] != self.eos_idx + ) + targets_flat = sample["target"].masked_select(pad_mask) + target_lengths = sample["target_lengths"] + + with torch.backends.cudnn.flags(enabled=False): + loss = F.ctc_loss( + lprobs, + targets_flat, + input_lengths, + target_lengths, + blank=self.blank_idx, + reduction="sum", + zero_infinity=self.zero_infinity, + ) + + ntokens = ( + sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item() + ) + + sample_size = sample["target"].size(0) if self.sentence_avg else ntokens + logging_output = { + "loss": utils.item(loss.data), # * sample['ntokens'], + "ntokens": ntokens, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + } + + if not model.training: + import editdistance + + with torch.no_grad(): + lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu() + + c_err = 0 + c_len = 0 + w_errs = 0 + w_len = 0 + wv_errs = 0 + for lp, t, inp_l in zip( + lprobs_t, + sample["target_label"] + if "target_label" in sample + else sample["target"], + input_lengths, + ): + lp = lp[:inp_l].unsqueeze(0) + + decoded = None + if self.w2l_decoder is not None: + decoded = self.w2l_decoder.decode(lp) + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + + p = (t != self.task.target_dictionary.pad()) & ( + t != self.task.target_dictionary.eos() + ) + targ = t[p] + targ_units = self.task.target_dictionary.string(targ) + targ_units_arr = targ.tolist() + + toks = lp.argmax(dim=-1).unique_consecutive() + pred_units_arr = toks[toks != self.blank_idx].tolist() + + c_err += editdistance.eval(pred_units_arr, targ_units_arr) + c_len += len(targ_units_arr) + + targ_words = post_process(targ_units, self.post_process).split() + + pred_units = self.task.target_dictionary.string(pred_units_arr) + pred_words_raw = post_process(pred_units, self.post_process).split() + + if decoded is not None and "words" in decoded: + pred_words = decoded["words"] + w_errs += editdistance.eval(pred_words, targ_words) + wv_errs += editdistance.eval(pred_words_raw, targ_words) + else: + dist = editdistance.eval(pred_words_raw, targ_words) + w_errs += dist + wv_errs += dist + + w_len += len(targ_words) + + logging_output["wv_errors"] = wv_errs + logging_output["w_errors"] = w_errs + logging_output["w_total"] = w_len + logging_output["c_errors"] = c_err + logging_output["c_total"] = c_len + + return loss, sample_size, logging_output + + def forward(self, model, sample, reduce=True): + import pdb + net_output = self.get_net_output(model, sample) + loss, sample_size, logging_output = self.get_loss(model, sample, net_output, reduce) + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log.get("c_total", 0) for log in logging_outputs) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log.get("w_errors", 0) for log in logging_outputs) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log.get("w_total", 0) for log in logging_outputs) + metrics.log_scalar("_w_total", w_total) + + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: safe_round( + meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3 + ) + if meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: safe_round( + meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: safe_round( + meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py new file mode 100644 index 0000000..b2eda1a --- /dev/null +++ b/fairseq/criterions/fairseq_criterion.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +from typing import Any, Dict, List + +from fairseq import metrics, utils +from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import DictConfig +from torch.nn.modules.loss import _Loss + + +class FairseqCriterion(_Loss): + def __init__(self, task): + super().__init__() + self.task = task + if hasattr(task, "target_dictionary"): + tgt_dict = task.target_dictionary + self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100 + + @classmethod + def add_args(cls, parser): + """Add criterion-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @classmethod + def build_criterion(cls, cfg: DictConfig, task): + """Construct a criterion from command-line args.""" + # arguments in the __init__. + init_args = {} + for p in inspect.signature(cls).parameters.values(): + if ( + p.kind == p.POSITIONAL_ONLY + or p.kind == p.VAR_POSITIONAL + or p.kind == p.VAR_KEYWORD + ): + # we haven't implemented inference for these argument types, + # but PRs welcome :) + raise NotImplementedError("{} not supported".format(p.kind)) + + assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} + + if p.name == "task": + init_args["task"] = task + elif hasattr(cfg, p.name): + init_args[p.name] = getattr(cfg, p.name) + elif p.default != p.empty: + pass # we'll use the default value + else: + raise NotImplementedError( + "Unable to infer Criterion arguments, please implement " + "{}.build_criterion".format(cls.__name__) + ) + return cls(**init_args) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + raise NotImplementedError + + @staticmethod + def aggregate_logging_outputs( + logging_outputs: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "The aggregate_logging_outputs API is deprecated. " + "Please use the reduce_metrics API instead." + ) + raise NotImplementedError + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + """Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "Criterions should implement the reduce_metrics API. " + "Falling back to deprecated aggregate_logging_outputs API." + ) + agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs) + for k, v in agg_logging_outputs.items(): + if k in {"nsentences", "ntokens", "sample_size"}: + continue + metrics.log_scalar(k, v) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False + + +class LegacyFairseqCriterion(FairseqCriterion): + def __init__(self, args, task): + super().__init__(task=task) + self.args = args + + utils.deprecation_warning( + "Criterions should take explicit arguments instead of an " + "argparse.Namespace object, please update your criterion by " + "extending FairseqCriterion instead of LegacyFairseqCriterion." + ) + + @classmethod + def build_criterion(cls, args, task): + """Construct a criterion from command-line args.""" + return cls(args, task) diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py new file mode 100644 index 0000000..2dc7f7a --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion + + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + if reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = epsilon / lprobs.size(-1) + loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + +@register_criterion("label_smoothed_cross_entropy") +class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + ): + super().__init__(task) + self.sentence_avg = sentence_avg + self.eps = label_smoothing + self.ignore_prefix_size = ignore_prefix_size + self.report_accuracy = report_accuracy + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--label-smoothing', default=0., type=float, metavar='D', + help='epsilon for label smoothing, 0 means no label smoothing') + parser.add_argument('--report-accuracy', action='store_true', + help='report accuracy metric') + parser.add_argument('--ignore-prefix-size', default=0, type=int, + help='Ignore first N tokens') + # fmt: on + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, model, net_output, sample): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + if self.ignore_prefix_size > 0: + if getattr(lprobs, "batch_first", False): + lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() + target = target[:, self.ignore_prefix_size :].contiguous() + else: + lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous() + target = target[self.ignore_prefix_size :, :].contiguous() + return lprobs.view(-1, lprobs.size(-1)), target.view(-1) + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) + return loss, nll_loss + + def compute_accuracy(self, model, net_output, sample): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + mask = target.ne(self.padding_idx) + n_correct = torch.sum( + lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)) + ) + total = torch.sum(mask) + return n_correct, total + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + + total = utils.item(sum(log.get("total", 0) for log in logging_outputs)) + if total > 0: + metrics.log_scalar("total", total) + n_correct = utils.item( + sum(log.get("n_correct", 0) for log in logging_outputs) + ) + metrics.log_scalar("n_correct", n_correct) + metrics.log_derived( + "accuracy", + lambda meters: round( + meters["n_correct"].sum * 100.0 / meters["total"].sum, 3 + ) + if meters["total"].sum > 0 + else float("nan"), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py new file mode 100644 index 0000000..73cfa05 --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import metrics, utils +from fairseq.criterions import register_criterion + +from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion + + +@register_criterion("label_smoothed_cross_entropy_with_alignment") +class LabelSmoothedCrossEntropyCriterionWithAlignment( + LabelSmoothedCrossEntropyCriterion +): + def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda): + super().__init__(task, sentence_avg, label_smoothing) + self.alignment_lambda = alignment_lambda + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + LabelSmoothedCrossEntropyCriterion.add_args(parser) + parser.add_argument( + "--alignment-lambda", + default=0.05, + type=float, + metavar="D", + help="weight for the alignment loss", + ) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + + alignment_loss = None + + # Compute alignment loss only for training set and non dummy batches. + if "alignments" in sample and sample["alignments"] is not None: + alignment_loss = self.compute_alignment_loss(sample, net_output) + + if alignment_loss is not None: + logging_output["alignment_loss"] = utils.item(alignment_loss.data) + loss += self.alignment_lambda * alignment_loss + + return loss, sample_size, logging_output + + def compute_alignment_loss(self, sample, net_output): + attn_prob = net_output[1]["attn"][0] + bsz, tgt_sz, src_sz = attn_prob.shape + attn = attn_prob.view(bsz * tgt_sz, src_sz) + + align = sample["alignments"] + align_weights = sample["align_weights"].float() + + if len(align) > 0: + # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to + # the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing. + loss = -( + (attn[align[:, 1][:, None], align[:, 0][:, None]]).log() + * align_weights[:, None] + ).sum() + else: + return None + + return loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + nll_loss_sum = utils.item( + sum(log.get("nll_loss", 0) for log in logging_outputs) + ) + alignment_loss_sum = utils.item( + sum(log.get("alignment_loss", 0) for log in logging_outputs) + ) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_scalar( + "alignment_loss", + alignment_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py new file mode 100644 index 0000000..f6bec34 --- /dev/null +++ b/fairseq/criterions/wav2vec_criterion.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.logging.meters import safe_round + + +@register_criterion("wav2vec") +class Wav2vecCriterion(FairseqCriterion): + def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): + super().__init__(task) + self.infonce = infonce + self.loss_weights = None if loss_weights is None else eval(loss_weights) + self.log_keys = [] if log_keys is None else eval(log_keys) + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--infonce', action='store_true', + help='if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)') + parser.add_argument('--loss-weights', type=str, default=None, + help='weights for additional loss terms (not first one)') + parser.add_argument('--log-keys', type=str, default=None, + help='output keys to log') + # fmt: on + + def get_net_output(self, model, sample): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + return net_output + + def get_loss(self, model, sample, net_output, reduce=True, log_pred=False): + logits = model.get_logits(net_output).float() + target = model.get_targets(sample, net_output) + + weights = None + if hasattr(model, "get_target_weights") and not self.infonce: + weights = model.get_target_weights(target, net_output) + if torch.is_tensor(weights): + weights = weights.float() + + losses = [] + + if self.infonce: + loss = F.cross_entropy( + logits, + target, + reduction="sum" if reduce else "none", + ) + else: + loss = F.binary_cross_entropy_with_logits( + logits, + target.float(), + weights, + reduction="sum" if reduce else "none", + ) + + sample_size = target.numel() if self.infonce else target.long().sum().item() + losses.append(loss.detach().clone()) + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, coef in zip(extra_losses, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + losses.append(p) + + logging_output = { + "loss": loss.item() if reduce else loss, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + } + + for lk in self.log_keys: + if lk in net_output: + logging_output[lk] = float((net_output[lk])) + + if len(losses) > 1: + for i, l in enumerate(losses): + logging_output[f"loss_{i}"] = l.item() + + if self.infonce: + with torch.no_grad(): + if logits.numel() == 0: + corr = 0 + count = 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + + logging_output["correct"] = corr + logging_output["count"] = count + + if log_pred: + logging_output["logits"] = logits.cpu().numpy() + logging_output["target"] = target.cpu().numpy() + return loss, sample_size, logging_output + + def forward(self, model, sample, reduce=True, log_pred=False): + net_output = self.get_net_output(model, sample) + loss, sample_size, logging_output = self.get_loss(model, sample, net_output, reduce, log_pred) + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + + correct = sum(log.get("correct", 0) for log in logging_outputs) + metrics.log_scalar("_correct", correct) + + total = sum(log.get("count", 0) for log in logging_outputs) + metrics.log_scalar("_total", total) + + if total > 0: + metrics.log_derived( + "accuracy", + lambda meters: safe_round( + meters["_correct"].sum / meters["_total"].sum, 5 + ) + if meters["_total"].sum > 0 + else float("nan"), + ) + + builtin_keys = { + "loss", + "ntokens", + "nsentences", + "sample_size", + "correct", + "count", + } + + for k in logging_outputs[0]: + if k not in builtin_keys: + val = sum(log.get(k, 0) for log in logging_outputs) / len( + logging_outputs + ) + if k.startswith("loss"): + metrics.log_scalar(k, val / sample_size / math.log(2), sample_size) + else: + metrics.log_scalar(k, val, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False diff --git a/fairseq/criterions/wav2vec_mtl.py b/fairseq/criterions/wav2vec_mtl.py new file mode 100644 index 0000000..3b6f191 --- /dev/null +++ b/fairseq/criterions/wav2vec_mtl.py @@ -0,0 +1,137 @@ +import torch +import math + +from fairseq import pdb +from fairseq import utils, metrics +from fairseq.criterions import LegacyFairseqCriterion, register_criterion +from fairseq.criterions.wav2vec_criterion import Wav2vecCriterion +from fairseq.criterions.ctc import CtcCriterion +from fairseq.logging.meters import safe_round + +@register_criterion('wav2vec_mtl') +class Wav2vecMTLCriterion(LegacyFairseqCriterion): + + def __init__(self, args, task): + super().__init__(args, task) + self.mtlalpha = args.mtlalpha + self.w2v_criterion = Wav2vecCriterion(task, args.infonce, args.loss_weights, args.log_keys) + if self.mtlalpha > 0: + self.ctc_criterion = CtcCriterion(args, task) + + @staticmethod + def add_args(parser): + Wav2vecCriterion.add_args(parser) + CtcCriterion.add_args(parser) + parser.add_argument('--mtlalpha', type=float, default=0.5) + + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + + if self.mtlalpha > 0.0: + ctc_loss, ctc_sample_size, ctc_logging_output = self.ctc_criterion.get_loss(model, sample, net_output, reduce) + else: + ctc_loss = 0 + ctc_sample_size = 0 + ctc_logging_output = {} + + infonce_loss, infonce_sample_size, infonce_logging_output = self.w2v_criterion.get_loss(model.w2v_encoder.w2v_model, sample, net_output['contrastive_res'], reduce) + loss = self.mtlalpha * ctc_loss + (1.0 - self.mtlalpha) * infonce_loss + sample_size = infonce_sample_size + logging_output = {'loss': loss, 'ntokens': ctc_logging_output['ntokens'], 'nsentences': ctc_logging_output['nsentences'], + 'ctc': ctc_logging_output, 'infonce': infonce_logging_output} + + return loss, sample_size, logging_output + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return False + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + + ctc_loss_sum = utils.item(sum(log['ctc'].get('loss', 0) for log in logging_outputs)) + ctc_sample_size = utils.item(sum(log['ctc'].get('sample_size', 0) for log in logging_outputs)) + ctc_ntokens = utils.item(sum(log['ctc'].get('ntokens', 0) for log in logging_outputs)) + ctc_nsentences = utils.item(sum(log['ctc'].get('nsentences', 0) for log in logging_outputs)) + + ctras_loss_sum = utils.item(sum(log['infonce'].get('loss', 0) for log in logging_outputs)) + ctras_sample_size = utils.item(sum(log['infonce'].get('sample_size', 0) for log in logging_outputs)) + ctras_ntokens = utils.item(sum(log['infonce'].get('ntokens', 0) for log in logging_outputs)) + ctras_nsentences = utils.item(sum(log['infonce'].get('nsentences', 0) for log in logging_outputs)) + + metrics.log_scalar( + "loss", loss_sum, 1, round=3) + metrics.log_scalar( + "ctc_loss", ctc_loss_sum / ctc_sample_size / math.log(2), ctc_sample_size, round=3 + ) + metrics.log_scalar( + "contrastive_loss", ctras_loss_sum / ctras_sample_size / math.log(2), ctras_sample_size, round=3 + ) + if ctc_sample_size != ctc_ntokens: + metrics.log_scalar( + "nll_loss", ctc_loss_sum / ctc_ntokens / math.log(2), ctc_ntokens, round=3 + ) + c_errors = sum(log['ctc'].get("c_errors", 0) for log in logging_outputs) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log['ctc'].get("c_total", 0) for log in logging_outputs) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log['ctc'].get("w_errors", 0) for log in logging_outputs) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = sum(log['ctc'].get("wv_errors", 0) for log in logging_outputs) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log['ctc'].get("w_total", 0) for log in logging_outputs) + metrics.log_scalar("_w_total", w_total) + + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: safe_round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3) + if meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: safe_round(meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: safe_round(meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + + + metrics.log_scalar("nsentences", ctras_nsentences) + metrics.log_scalar("ctc_sample_size", ctc_sample_size) + metrics.log_scalar("contrastive_sample_size", ctras_sample_size) + + + correct = sum(log['infonce'].get("correct", 0) for log in logging_outputs) + metrics.log_scalar("_correct", correct) + + total = sum(log['infonce'].get("count", 0) for log in logging_outputs) + metrics.log_scalar("_total", total) + + + if total > 0: + metrics.log_derived( + "accuracy", + lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5) + if meters["_total"].sum > 0 + else float("nan"), + ) + + builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'} + for k in logging_outputs[0]['infonce']: + if k not in builtin_keys: + val = sum(log['infonce'].get(k, 0) for log in logging_outputs) / len(logging_outputs) + if k.startswith('loss'): + metrics.log_scalar(k, val / ctras_sample_size / math.log(2), ctras_sample_size) + else: + metrics.log_scalar(k, val, round=3) + diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py new file mode 100644 index 0000000..073991d --- /dev/null +++ b/fairseq/data/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .dictionary import Dictionary, TruncatedDictionary + +from .fairseq_dataset import FairseqDataset, FairseqIterableDataset + +from .base_wrapper_dataset import BaseWrapperDataset + +from .add_target_dataset import AddTargetDataset +from .audio.raw_audio_dataset import FileAudioDataset +from .concat_dataset import ConcatDataset +from .id_dataset import IdDataset +from .resampling_dataset import ResamplingDataset + +from .iterators import ( + CountingIterator, + EpochBatchIterator, + GroupedIterator, + ShardedIterator, +) + +__all__ = [ + "AddTargetDataset", + "ConcatDataset", + "CountingIterator", + "Dictionary", + "EpochBatchIterator", + "FairseqDataset", + "FairseqIterableDataset", + "FastaDataset", + "GroupedIterator", + "IdDataset", + "ResamplingDataset", + "ShardedIterator", +] diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py new file mode 100644 index 0000000..9ef4670 --- /dev/null +++ b/fairseq/data/add_target_dataset.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset, data_utils + + +class AddTargetDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + labels, + pad, + eos, + batch_targets, + process_label=None, + add_to_input=False, + ): + super().__init__(dataset) + self.labels = labels + self.batch_targets = batch_targets + self.pad = pad + self.eos = eos + self.process_label = process_label + self.add_to_input = add_to_input + + def get_label(self, index): + return ( + self.labels[index] + if self.process_label is None + else self.process_label(self.labels[index]) + ) + + def __getitem__(self, index): + item = self.dataset[index] + item["label"] = self.get_label(index) + return item + + def size(self, index): + sz = self.dataset.size(index) + own_sz = len(self.get_label(index)) + return (sz, own_sz) + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + + if self.batch_targets: + collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) + target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) + collated["ntokens"] = collated["target_lengths"].sum().item() + else: + collated["ntokens"] = sum([len(t) for t in target]) + + collated["target"] = target + + if self.add_to_input: + eos = target.new_full((target.size(0), 1), self.eos) + collated["target"] = torch.cat([target, eos], dim=-1).long() + collated["net_input"]["prev_output_tokens"] = torch.cat( + [eos, target], dim=-1 + ).long() + collated["ntokens"] += target.size(0) + return collated diff --git a/fairseq/data/audio/__init__.py b/fairseq/data/audio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py new file mode 100644 index 0000000..de08669 --- /dev/null +++ b/fairseq/data/audio/audio_utils.py @@ -0,0 +1,85 @@ +import os.path as op +from typing import BinaryIO, Optional, Tuple, Union + +import numpy as np + + +def get_waveform( + path_or_fp: Union[str, BinaryIO], normalization=True +) -> Tuple[np.ndarray, int]: + """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC. + + Args: + path_or_fp (str or BinaryIO): the path or file-like object + normalization (bool): Normalize values to [-1, 1] (Default: True) + """ + if isinstance(path_or_fp, str): + ext = op.splitext(op.basename(path_or_fp))[1] + if ext not in {".flac", ".wav"}: + raise ValueError(f"Unsupported audio format: {ext}") + + try: + import soundfile as sf + except ImportError: + raise ImportError("Please install soundfile to load WAV/FLAC file") + + waveform, sample_rate = sf.read(path_or_fp, dtype="float32") + if not normalization: + waveform *= 2 ** 15 # denormalized to 16-bit signed integers + return waveform, sample_rate + + +def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: + """Get mel-filter bank features via PyKaldi.""" + try: + from kaldi.feat.mel import MelBanksOptions + from kaldi.feat.fbank import FbankOptions, Fbank + from kaldi.feat.window import FrameExtractionOptions + from kaldi.matrix import Vector + + mel_opts = MelBanksOptions() + mel_opts.num_bins = n_bins + frame_opts = FrameExtractionOptions() + frame_opts.samp_freq = sample_rate + opts = FbankOptions() + opts.mel_opts = mel_opts + opts.frame_opts = frame_opts + fbank = Fbank(opts=opts) + features = fbank.compute(Vector(waveform), 1.0).numpy() + return features + except ImportError: + return None + + +def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]: + """Get mel-filter bank features via TorchAudio.""" + try: + import torch + import torchaudio.compliance.kaldi as ta_kaldi + + waveform = torch.from_numpy(waveform).unsqueeze(0) + features = ta_kaldi.fbank( + waveform, num_mel_bins=n_bins, sample_frequency=sample_rate + ) + return features.numpy() + except ImportError: + return None + + +def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: + """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi + (faster CPP implementation) to TorchAudio (Python implementation). Note that + Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the + waveform should not be normalized.""" + sound, sample_rate = get_waveform(path_or_fp, normalization=False) + + features = _get_kaldi_fbank(sound, sample_rate, n_bins) + if features is None: + features = _get_torchaudio_fbank(sound, sample_rate, n_bins) + if features is None: + raise ImportError( + "Please install pyKaldi or torchaudio to enable " + "online filterbank feature extraction" + ) + + return features diff --git a/fairseq/data/audio/feature_transforms/__init__.py b/fairseq/data/audio/feature_transforms/__init__.py new file mode 100644 index 0000000..359fa06 --- /dev/null +++ b/fairseq/data/audio/feature_transforms/__init__.py @@ -0,0 +1,82 @@ +import importlib +import os +from abc import ABC, abstractmethod +from typing import Dict, Optional + + +class AudioFeatureTransform(ABC): + @classmethod + @abstractmethod + def from_config_dict(cls, config: Optional[Dict] = None): + pass + + +AUDIO_FEATURE_TRANSFORM_REGISTRY = {} +AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set() + + +def register_audio_feature_transform(name): + def register_audio_feature_transform_cls(cls): + if name in AUDIO_FEATURE_TRANSFORM_REGISTRY: + raise ValueError(f"Cannot register duplicate transform ({name})") + if not issubclass(cls, AudioFeatureTransform): + raise ValueError( + f"Transform ({name}: {cls.__name__}) must extend " + "AudioFeatureTransform" + ) + if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES: + raise ValueError( + f"Cannot register audio feature transform with duplicate " + f"class name ({cls.__name__})" + ) + AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls + AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__) + return cls + + return register_audio_feature_transform_cls + + +def get_audio_feature_transform(name): + return AUDIO_FEATURE_TRANSFORM_REGISTRY[name] + + +transforms_dir = os.path.dirname(__file__) +for file in os.listdir(transforms_dir): + path = os.path.join(transforms_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module("fairseq.data.audio.feature_transforms." + name) + + +class CompositeAudioFeatureTransform(AudioFeatureTransform): + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + _transforms = _config.get("transforms") + if _transforms is None: + return None + transforms = [ + get_audio_feature_transform(_t).from_config_dict(_config.get(_t)) + for _t in _transforms + ] + return CompositeAudioFeatureTransform(transforms) + + def __init__(self, transforms): + self.transforms = [t for t in transforms if t is not None] + + def __call__(self, x): + for t in self.transforms: + x = t(x) + return x + + def __repr__(self): + format_string = ( + [self.__class__.__name__ + "("] + + [f" {t.__repr__()}" for t in self.transforms] + + [")"] + ) + return "\n".join(format_string) diff --git a/fairseq/data/audio/feature_transforms/global_cmvn.py b/fairseq/data/audio/feature_transforms/global_cmvn.py new file mode 100644 index 0000000..d512fed --- /dev/null +++ b/fairseq/data/audio/feature_transforms/global_cmvn.py @@ -0,0 +1,25 @@ +import numpy as np +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, + register_audio_feature_transform, +) + + +@register_audio_feature_transform("global_cmvn") +class GlobalCMVN(AudioFeatureTransform): + """Global CMVN (cepstral mean and variance normalization). The global mean + and variance need to be pre-computed and stored in NumPy format (.npz).""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return GlobalCMVN(_config.get("stats_npz_path")) + + def __init__(self, stats_npz_path): + stats = np.load(stats_npz_path) + self.mean, self.std = stats["mean"], stats["std"] + + def __call__(self, x): + x = np.subtract(x, self.mean) + x = np.divide(x, self.std) + return x diff --git a/fairseq/data/audio/feature_transforms/specaugment.py b/fairseq/data/audio/feature_transforms/specaugment.py new file mode 100644 index 0000000..2ef4778 --- /dev/null +++ b/fairseq/data/audio/feature_transforms/specaugment.py @@ -0,0 +1,131 @@ +import math +import numbers +from typing import Optional + +import numpy as np +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, + register_audio_feature_transform, +) + + +@register_audio_feature_transform("specaugment") +class SpecAugmentTransform(AudioFeatureTransform): + """SpecAugment (https://arxiv.org/abs/1904.08779)""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return SpecAugmentTransform( + _config.get("time_warp_W", 0), + _config.get("freq_mask_N", 0), + _config.get("freq_mask_F", 0), + _config.get("time_mask_N", 0), + _config.get("time_mask_T", 0), + _config.get("time_mask_p", 0.0), + _config.get("mask_value", None), + ) + + def __init__( + self, + time_warp_w: int = 0, + freq_mask_n: int = 0, + freq_mask_f: int = 0, + time_mask_n: int = 0, + time_mask_t: int = 0, + time_mask_p: float = 0.0, + mask_value: Optional[float] = 0.0, + ): + # Sanity checks + assert mask_value is None or isinstance( + mask_value, numbers.Number + ), f"mask_value (type: {type(mask_value)}) must be None or a number" + if freq_mask_n > 0: + assert freq_mask_f > 0, ( + f"freq_mask_F ({freq_mask_f}) " + f"must be larger than 0 when doing freq masking." + ) + if time_mask_n > 0: + assert time_mask_t > 0, ( + f"time_mask_T ({time_mask_t}) must be larger than 0 when " + f"doing time masking." + ) + + self.time_warp_w = time_warp_w + self.freq_mask_n = freq_mask_n + self.freq_mask_f = freq_mask_f + self.time_mask_n = time_mask_n + self.time_mask_t = time_mask_t + self.time_mask_p = time_mask_p + self.mask_value = mask_value + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + ", ".join( + [ + f"time_warp_w={self.time_warp_w}", + f"freq_mask_n={self.freq_mask_n}", + f"freq_mask_f={self.freq_mask_f}", + f"time_mask_n={self.time_mask_n}", + f"time_mask_t={self.time_mask_t}", + f"time_mask_p={self.time_mask_p}", + ] + ) + + ")" + ) + + def __call__(self, spectrogram): + assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor." + + distorted = spectrogram.copy() # make a copy of input spectrogram. + num_frames = spectrogram.shape[0] # or 'tau' in the paper. + num_freqs = spectrogram.shape[1] # or 'miu' in the paper. + mask_value = self.mask_value + + if mask_value is None: # if no value was specified, use local mean. + mask_value = spectrogram.mean() + + if num_frames == 0: + return spectrogram + + if num_freqs < self.freq_mask_f: + return spectrogram + + if self.time_warp_w > 0: + if 2 * self.time_warp_w < num_frames: + import cv2 + + w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w) + w = np.random.randint(0, self.time_warp_w) + upper, lower = distorted[:w0, :], distorted[w0:, :] + upper = cv2.resize( + upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR + ) + lower = cv2.resize( + lower, + dsize=(num_freqs, num_frames - w0 - w), + interpolation=cv2.INTER_LINEAR, + ) + distorted = np.concatenate((upper, lower), axis=0) + + for _i in range(self.freq_mask_n): + f = np.random.randint(0, self.freq_mask_f) + f0 = np.random.randint(0, num_freqs - f) + if f != 0: + distorted[:, f0 : f0 + f] = mask_value + + max_time_mask_t = min( + self.time_mask_t, math.floor(num_frames * self.time_mask_p) + ) + if max_time_mask_t < 1: + return distorted + + for _i in range(self.time_mask_n): + t = np.random.randint(0, max_time_mask_t) + t0 = np.random.randint(0, num_frames - t) + if t != 0: + distorted[t0 : t0 + t, :] = mask_value + + return distorted diff --git a/fairseq/data/audio/feature_transforms/utterance_cmvn.py b/fairseq/data/audio/feature_transforms/utterance_cmvn.py new file mode 100644 index 0000000..6bbd0ae --- /dev/null +++ b/fairseq/data/audio/feature_transforms/utterance_cmvn.py @@ -0,0 +1,40 @@ +import numpy as np +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, + register_audio_feature_transform, +) + + +@register_audio_feature_transform("utterance_cmvn") +class UtteranceCMVN(AudioFeatureTransform): + """Utterance-level CMVN (cepstral mean and variance normalization)""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return UtteranceCMVN( + _config.get("norm_means", True), + _config.get("norm_vars", True), + ) + + def __init__(self, norm_means=True, norm_vars=True): + self.norm_means, self.norm_vars = norm_means, norm_vars + + def __repr__(self): + return ( + self.__class__.__name__ + + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})" + ) + + def __call__(self, x): + mean = x.mean(axis=0) + square_sums = (x ** 2).sum(axis=0) + + if self.norm_means: + x = np.subtract(x, mean) + if self.norm_vars: + var = square_sums / x.shape[0] - mean ** 2 + std = np.sqrt(np.maximum(var, 1e-10)) + x = np.divide(x, std) + + return x diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py new file mode 100644 index 0000000..86270d9 --- /dev/null +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -0,0 +1,192 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +import sys + +import numpy as np +import torch +import torch.nn.functional as F + +from .. import FairseqDataset + + +logger = logging.getLogger(__name__) + + +class RawAudioDataset(FairseqDataset): + def __init__( + self, + sample_rate, + max_sample_size=None, + min_sample_size=None, + shuffle=True, + min_length=0, + pad=False, + normalize=False, + ): + super().__init__() + + self.sample_rate = sample_rate + self.sizes = [] + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.min_sample_size = min_sample_size + self.min_length = min_length + self.pad = pad + self.shuffle = shuffle + self.normalize = normalize + + def __getitem__(self, index): + raise NotImplementedError() + + def __len__(self): + return len(self.sizes) + + def postprocess(self, feats, curr_sample_rate): + if feats.dim() == 2: + feats = feats.mean(-1) + + if curr_sample_rate != self.sample_rate: + raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}") + + assert feats.dim() == 1, feats.dim() + + if self.normalize: + with torch.no_grad(): + feats = F.layer_norm(feats, feats.shape) + return feats + + def crop_to_max_size(self, wav, target_size): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav + + start = np.random.randint(0, diff + 1) + end = size - diff + start + return wav[start:end] + + def collater(self, samples): + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + sources = [s["source"] for s in samples] + sizes = [len(s) for s in sources] + + if self.pad: + target_size = min(max(sizes), self.max_sample_size) + else: + target_size = min(min(sizes), self.max_sample_size) + + collated_sources = sources[0].new_zeros(len(sources), target_size) + padding_mask = ( + torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None + ) + for i, (source, size) in enumerate(zip(sources, sizes)): + diff = size - target_size + if diff == 0: + collated_sources[i] = source + elif diff < 0: + assert self.pad + collated_sources[i] = torch.cat( + [source, source.new_full((-diff,), 0.0)] + ) + padding_mask[i, diff:] = True + else: + collated_sources[i] = self.crop_to_max_size(source, target_size) + + input = {"source": collated_sources} + if self.pad: + input["padding_mask"] = padding_mask + return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + if self.pad: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.sizes) + return np.lexsort(order)[::-1] + + +class FileAudioDataset(RawAudioDataset): + def __init__( + self, + manifest_path, + sample_rate, + max_sample_size=None, + min_sample_size=None, + shuffle=True, + min_length=0, + pad=False, + normalize=False, + ): + super().__init__( + sample_rate=sample_rate, + max_sample_size=max_sample_size, + min_sample_size=min_sample_size, + shuffle=shuffle, + min_length=min_length, + pad=pad, + normalize=normalize, + ) + + self.fnames = [] + self.skipped = [] + + skipped = 0 + count = 0 + with open(manifest_path, "r") as f: + self.root_dir = f.readline().strip() + for line in f: + count += 1 + items = line.strip().split("\t") + #assert len(items) == 2, line + sz = int(items[1]) + if len(items) == 3: + sr = int(items[2]) + assert sz % (sr / self.sample_rate) == 0 + sz = sz / (sr / self.sample_rate) + if min_length is not None and sz < min_length: + skipped += 1 + self.skipped.append(count) + continue + if pad and max_sample_size is not None and sz > max_sample_size: + skipped += 1 + self.skipped.append(count) + continue + self.fnames.append(items[0]) + self.sizes.append(int(sz)) + self.sizes = np.array(self.sizes) + logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") + + def __getitem__(self, index): + import soundfile as sf + + fname = os.path.join(self.root_dir, self.fnames[index]) + wav, curr_sample_rate = sf.read(fname) + #wav, curr_sample_rate = librosa.load(fname, sr=self.sample_rate) + feats = torch.from_numpy(wav).float() + feats = self.postprocess(feats, curr_sample_rate) + return {"id": index, "source": feats} diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py new file mode 100644 index 0000000..39d22c7 --- /dev/null +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -0,0 +1,528 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import csv +import io +import logging +import os.path as op +import re +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from fairseq.data import ( + ConcatDataset, + Dictionary, + FairseqDataset, + ResamplingDataset, + data_utils as fairseq_data_utils, +) +from fairseq.data.audio.audio_utils import get_fbank, get_waveform +from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform + + +logger = logging.getLogger(__name__) + + +class S2TDataConfig(object): + """Wrapper class for data config YAML""" + + def __init__(self, yaml_path): + try: + import yaml + except ImportError: + print("Please install PyYAML to load YAML files for " "S2T data config") + self.config = {} + if op.isfile(yaml_path): + try: + with open(yaml_path) as f: + self.config = yaml.load(f, Loader=yaml.FullLoader) + except Exception as e: + logger.info(f"Failed to load config from {yaml_path}: {e}") + else: + logger.info(f"Cannot find {yaml_path}") + + @property + def vocab_filename(self): + """fairseq vocabulary file under data root""" + return self.config.get("vocab_filename", "dict.txt") + + @property + def shuffle(self) -> bool: + """Shuffle dataset samples before batching""" + return self.config.get("shuffle", False) + + @property + def pre_tokenizer(self) -> Dict: + """Pre-tokenizer to apply before subword tokenization. Returning + a dictionary with `tokenizer` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + return self.config.get("pre_tokenizer", {"tokenizer": None}) + + @property + def bpe_tokenizer(self) -> Dict: + """Subword tokenizer to apply after pre-tokenization. Returning + a dictionary with `bpe` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + return self.config.get("bpe_tokenizer", {"bpe": None}) + + @property + def prepend_tgt_lang_tag(self) -> bool: + """Prepend target lang ID token as the target BOS (e.g. for to-many + multilingual setting). During inference, this requires `--prefix-size 1` + to force BOS to be lang ID token.""" + return self.config.get("prepend_tgt_lang_tag", False) + + @property + def input_feat_per_channel(self): + """The dimension of input features (per audio channel)""" + return self.config.get("input_feat_per_channel", 80) + + @property + def input_channels(self): + """The number of channels in the input audio""" + return self.config.get("input_channels", 1) + + @property + def sampling_alpha(self): + """Hyper-parameter alpha = 1/T for temperature-based resampling. + (alpha = 1 for no resampling)""" + return self.config.get("sampling_alpha", 1.0) + + @property + def use_audio_input(self): + """Needed by the dataset loader to see if the model requires + raw audio as inputs.""" + return self.config.get("use_audio_input", False) + + @property + def audio_root(self): + """Audio paths in the manifest TSV can be relative and this provides + the root path. Set this to empty string when using absolute paths.""" + return self.config.get("audio_root", "") + + def get_feature_transforms(self, split, is_train): + """Split-specific feature transforms. Allowing train set wildcard `_train`, + evaluation set wildcard `_eval` and general wildcard `*` for matching.""" + from copy import deepcopy + + cfg = deepcopy(self.config) + _cur = cfg.get("transforms", {}) + cur = _cur.get(split) + cur = _cur.get("_train") if cur is None and is_train else cur + cur = _cur.get("_eval") if cur is None and not is_train else cur + cur = _cur.get("*") if cur is None else cur + cfg["transforms"] = cur + return cfg + + +def is_npy_data(data: bytes) -> bool: + return data[0] == 147 and data[1] == 78 + + +def is_flac_or_wav_data(data: bytes) -> bool: + is_flac = data[0] == 102 and data[1] == 76 + is_wav = data[0] == 82 and data[1] == 73 + return is_flac or is_wav + + +def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes: + with open(file_path, "rb") as f: + f.seek(offset) + data = f.read(file_size) + return data + + +def get_features_from_npy_or_audio(path): + ext = op.splitext(op.basename(path))[1] + if ext not in {".npy", ".flac", ".wav"}: + raise ValueError(f'Unsupported file format for "{path}"') + return np.load(path) if ext == ".npy" else get_fbank(path) + + +def get_features_or_waveform_from_uncompressed_zip( + path, byte_offset, byte_size, need_waveform=False +): + assert path.endswith(".zip") + data = read_from_uncompressed_zip(path, byte_offset, byte_size) + f = io.BytesIO(data) + if is_npy_data(data): + features_or_waveform = np.load(f) + elif is_flac_or_wav_data(data): + features_or_waveform = get_waveform(f)[0] if need_waveform else get_fbank(f) + else: + raise ValueError(f'Unknown file format for "{path}"') + return features_or_waveform + + +def get_features_or_waveform(path: str, need_waveform=False): + """Get speech features from .npy file or waveform from .wav/.flac file. + The file may be inside an uncompressed ZIP file and is accessed via byte + offset and length. + + Args: + path (str): File path in the format of "<.npy/.wav/.flac path>" or + "::". + need_waveform (bool): return waveform instead of features. + + Returns: + features_or_waveform (numpy.ndarray): speech features or waveform. + """ + _path, *extra = path.split(":") + if not op.exists(_path): + raise FileNotFoundError(f"File not found: {_path}") + + if len(extra) == 0: + if need_waveform: + return get_waveform(_path) + return get_features_from_npy_or_audio(_path) + elif len(extra) == 2: + extra = [int(i) for i in extra] + features_or_waveform = get_features_or_waveform_from_uncompressed_zip( + _path, extra[0], extra[1], need_waveform=need_waveform + ) + else: + raise ValueError(f"Invalid path: {path}") + + return features_or_waveform + + +def _collate_frames( + frames: List[torch.Tensor], is_audio_input: bool = False +) -> torch.Tensor: + """ + Convert a list of 2D frames into a padded 3D tensor + Args: + frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + max_len = max(frame.size(0) for frame in frames) + if is_audio_input: + out = frames[0].new_zeros((len(frames), max_len)) + else: + out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) + for i, v in enumerate(frames): + out[i, : v.size(0)] = v + return out + + +class SpeechToTextDataset(FairseqDataset): + LANG_TAG_TEMPLATE = "" + + def __init__( + self, + split: str, + is_train_split: bool, + data_cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + ): + self.split, self.is_train_split = split, is_train_split + self.data_cfg = data_cfg + self.audio_paths, self.n_frames = audio_paths, n_frames + self.n_samples = len(audio_paths) + assert len(n_frames) == self.n_samples > 0 + assert src_texts is None or len(src_texts) == self.n_samples + assert tgt_texts is None or len(tgt_texts) == self.n_samples + assert speakers is None or len(speakers) == self.n_samples + assert src_langs is None or len(src_langs) == self.n_samples + assert tgt_langs is None or len(tgt_langs) == self.n_samples + assert ids is None or len(ids) == self.n_samples + assert (tgt_dict is None and tgt_texts is None) or ( + tgt_dict is not None and tgt_texts is not None + ) + self.src_texts, self.tgt_texts = src_texts, tgt_texts + self.src_langs, self.tgt_langs = src_langs, tgt_langs + self.tgt_dict = tgt_dict + self.check_tgt_lang_tag() + self.ids = ids + self.shuffle = data_cfg.shuffle if is_train_split else False + + self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict( + self.data_cfg.get_feature_transforms(split, is_train_split) + ) + + self.pre_tokenizer = pre_tokenizer + self.bpe_tokenizer = bpe_tokenizer + + logger.info(self.__repr__()) + + def __repr__(self): + return ( + self.__class__.__name__ + + f'(split="{self.split}", n_samples={self.n_samples}, ' + f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, " + f"shuffle={self.shuffle}, transforms={self.feature_transforms})" + ) + + @classmethod + def is_lang_tag(cls, token): + pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)") + return re.match(pattern, token) + + def check_tgt_lang_tag(self): + if self.data_cfg.prepend_tgt_lang_tag: + assert self.tgt_langs is not None and self.tgt_dict is not None + tgt_lang_tags = [ + self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs) + ] + assert all(t in self.tgt_dict for t in tgt_lang_tags) + + def tokenize_text(self, text: str): + if self.pre_tokenizer is not None: + text = self.pre_tokenizer.encode(text) + if self.bpe_tokenizer is not None: + text = self.bpe_tokenizer.encode(text) + return text + + def __getitem__( + self, index: int + ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]: + source = get_features_or_waveform( + self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input + ) + if self.feature_transforms is not None: + assert not self.data_cfg.use_audio_input + source = self.feature_transforms(source) + source = torch.from_numpy(source).float() + + target = None + if self.tgt_texts is not None: + tokenized = self.tokenize_text(self.tgt_texts[index]) + target = self.tgt_dict.encode_line( + tokenized, add_if_not_exist=False, append_eos=True + ).long() + if self.data_cfg.prepend_tgt_lang_tag: + lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index]) + lang_tag_idx = self.tgt_dict.index(lang_tag) + target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0) + return index, source, target + + def __len__(self): + return self.n_samples + + def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dict: + if len(samples) == 0: + return {} + indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long) + frames = _collate_frames( + [s for _, s, _ in samples], self.data_cfg.use_audio_input + ) + # sort samples by descending number of frames + n_frames = torch.tensor([s.size(0) for _, s, _ in samples], dtype=torch.long) + n_frames, order = n_frames.sort(descending=True) + indices = indices.index_select(0, order) + frames = frames.index_select(0, order) + + target, target_lengths = None, None + prev_output_tokens = None + ntokens = None + if self.tgt_texts is not None: + target = fairseq_data_utils.collate_tokens( + [t for _, _, t in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ) + target = target.index_select(0, order) + target_lengths = torch.tensor( + [t.size(0) for _, _, t in samples], dtype=torch.long + ).index_select(0, order) + prev_output_tokens = fairseq_data_utils.collate_tokens( + [t for _, _, t in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=True, + ) + prev_output_tokens = prev_output_tokens.index_select(0, order) + ntokens = sum(t.size(0) for _, _, t in samples) + + out = { + "id": indices, + "net_input": { + "src_tokens": frames, + "src_lengths": n_frames, + "prev_output_tokens": prev_output_tokens, + }, + "target": target, + "target_lengths": target_lengths, + "ntokens": ntokens, + "nsentences": len(samples), + } + return out + + def num_tokens(self, index): + return self.n_frames[index] + + def size(self, index): + t_len = 0 + if self.tgt_texts is not None: + tokenized = self.tokenize_text(self.tgt_texts[index]) + t_len = len(tokenized.split(" ")) + return self.n_frames[index], t_len + + @property + def sizes(self): + return np.array(self.n_frames) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + # first by descending order of # of frames then by original/random order + order.append([-n for n in self.n_frames]) + return np.lexsort(order) + + def prefetch(self, indices): + raise False + + +class SpeechToTextDatasetCreator(object): + # mandatory columns + KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames" + KEY_TGT_TEXT = "tgt_text" + # optional columns + KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text" + KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang" + # default values + DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = "" + + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[List[Dict]], + data_cfg: S2TDataConfig, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + ) -> SpeechToTextDataset: + audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], [] + speakers, src_langs, tgt_langs = [], [], [] + for s in samples: + ids.extend([ss[cls.KEY_ID] for ss in s]) + audio_paths.extend( + [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s] + ) + n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s]) + tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s]) + src_texts.extend( + [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s] + ) + speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]) + src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]) + tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]) + return SpeechToTextDataset( + split_name, + is_train_split, + data_cfg, + audio_paths, + n_frames, + src_texts, + tgt_texts, + speakers, + src_langs, + tgt_langs, + ids, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + ) + + @classmethod + def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0): + """Size ratios for temperature-based sampling + (https://arxiv.org/abs/1907.05019)""" + _sizes = np.array(sizes) + prob = _sizes / _sizes.sum() + smoothed_prob = prob ** alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + size_ratio = (smoothed_prob * _sizes.sum()) / _sizes + + o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)}) + logger.info(f"original sampling probability: {o_str}") + p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)}) + logger.info(f"balanced sampling probability: {p_str}") + sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)}) + logger.info(f"balanced sampling size ratio: {sr_str}") + return size_ratio.tolist() + + @classmethod + def from_tsv( + cls, + root: str, + data_cfg: S2TDataConfig, + splits: str, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split: bool, + epoch: int, + seed: int, + ) -> SpeechToTextDataset: + samples = [] + _splits = splits.split(",") + for split in _splits: + tsv_path = op.join(root, f"{split}.tsv") + if not op.isfile(tsv_path): + raise FileNotFoundError(f"Dataset not found: {tsv_path}") + with open(tsv_path) as f: + reader = csv.DictReader( + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + samples.append([dict(e) for e in reader]) + assert len(samples) > 0 + + datasets = [ + cls._from_list( + name, + is_train_split, + [s], + data_cfg, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + ) + for name, s in zip(_splits, samples) + ] + + if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.0: + # temperature-based sampling + size_ratios = cls._get_size_ratios( + _splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha + ) + datasets = [ + ResamplingDataset( + d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) + ) + for d, r in zip(datasets, size_ratios) + ] + return ConcatDataset(datasets) diff --git a/fairseq/data/base_wrapper_dataset.py b/fairseq/data/base_wrapper_dataset.py new file mode 100644 index 0000000..134d398 --- /dev/null +++ b/fairseq/data/base_wrapper_dataset.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch.utils.data.dataloader import default_collate + +from . import FairseqDataset + + +class BaseWrapperDataset(FairseqDataset): + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if hasattr(self.dataset, "collater"): + return self.dataset.collater(samples) + else: + return default_collate(samples) + + @property + def sizes(self): + return self.dataset.sizes + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + def ordered_indices(self): + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def attr(self, attr: str, index: int): + return self.dataset.attr(attr, index) + + def prefetch(self, indices): + self.dataset.prefetch(indices) + + def get_batch_shapes(self): + return self.dataset.get_batch_shapes() + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + return self.dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + def filter_indices_by_size(self, indices, max_sizes): + return self.dataset.filter_indices_by_size(indices, max_sizes) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return self.dataset.can_reuse_epoch_itr_across_epochs + + def set_epoch(self, epoch): + super().set_epoch(epoch) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) diff --git a/fairseq/data/concat_dataset.py b/fairseq/data/concat_dataset.py new file mode 100644 index 0000000..01a4078 --- /dev/null +++ b/fairseq/data/concat_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import bisect + +import numpy as np +from torch.utils.data.dataloader import default_collate + +from . import FairseqDataset + + +class ConcatDataset(FairseqDataset): + @staticmethod + def cumsum(sequence, sample_ratios): + r, s = [], 0 + for e, ratio in zip(sequence, sample_ratios): + curr_len = int(ratio * len(e)) + r.append(curr_len + s) + s += curr_len + return r + + def __init__(self, datasets, sample_ratios=1): + super(ConcatDataset, self).__init__() + assert len(datasets) > 0, "datasets should not be an empty iterable" + self.datasets = list(datasets) + if isinstance(sample_ratios, int): + sample_ratios = [sample_ratios] * len(self.datasets) + self.sample_ratios = sample_ratios + self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) + self.real_sizes = [len(d) for d in self.datasets] + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx][sample_idx] + + def _get_dataset_and_sample_index(self, idx: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + sample_idx = sample_idx % self.real_sizes[dataset_idx] + return dataset_idx, sample_idx + + def collater(self, samples, **extra_args): + # For now only supports datasets with same underlying collater implementations + if hasattr(self.datasets[0], "collater"): + return self.datasets[0].collater(samples, **extra_args) + else: + return default_collate(samples, **extra_args) + + def size(self, idx: int): + """ + Return an example's size as a float or tuple. + """ + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx].size(sample_idx) + + def num_tokens(self, index: int): + return np.max(self.size(index)) + + def attr(self, attr: str, index: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, index) + return getattr(self.datasets[dataset_idx], attr, None) + + @property + def sizes(self): + _dataset_sizes = [] + for ds, sr in zip(self.datasets, self.sample_ratios): + if isinstance(ds.sizes, np.ndarray): + _dataset_sizes.append(np.tile(ds.sizes, sr)) + else: + # Only support underlying dataset with single size array. + assert isinstance(ds.sizes, list) + _dataset_sizes.append(np.tile(ds.sizes[0], sr)) + return np.concatenate(_dataset_sizes) + + @property + def supports_prefetch(self): + return all(d.supports_prefetch for d in self.datasets) + + def ordered_indices(self): + """ + Returns indices sorted by length. So less padding is needed. + """ + if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1: + # special handling for concatenating lang_pair_datasets + indices = np.arange(len(self)) + sizes = self.sizes + tgt_sizes = ( + sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + ) + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(src_sizes[indices], kind="mergesort")] + else: + return np.argsort(self.sizes) + + def prefetch(self, indices): + frm = 0 + for to, ds in zip(self.cumulative_sizes, self.datasets): + real_size = len(ds) + if getattr(ds, "supports_prefetch", False): + ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) + frm = to + + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) diff --git a/fairseq/data/concat_sentences_dataset.py b/fairseq/data/concat_sentences_dataset.py new file mode 100644 index 0000000..625a293 --- /dev/null +++ b/fairseq/data/concat_sentences_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class ConcatSentencesDataset(FairseqDataset): + def __init__(self, *datasets): + super().__init__() + self.datasets = datasets + assert all( + len(ds) == len(datasets[0]) for ds in datasets + ), "datasets must have the same length" + + def __getitem__(self, index): + return torch.cat([ds[index] for ds in self.datasets]) + + def __len__(self): + return len(self.datasets[0]) + + def collater(self, samples): + return self.datasets[0].collater(samples) + + @property + def sizes(self): + return sum(ds.sizes for ds in self.datasets) + + def num_tokens(self, index): + return sum(ds.num_tokens(index) for ds in self.datasets) + + def size(self, index): + return sum(ds.size(index) for ds in self.datasets) + + def ordered_indices(self): + return self.datasets[0].ordered_indices() + + @property + def supports_prefetch(self): + return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets) + + def prefetch(self, indices): + for ds in self.datasets: + if getattr(ds, "supports_prefetch", False): + ds.prefetch(indices) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py new file mode 100644 index 0000000..81f4573 --- /dev/null +++ b/fairseq/data/data_utils.py @@ -0,0 +1,499 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable +import contextlib +import itertools +import logging +import os +import warnings +from typing import Optional, Tuple + +import numpy as np +import torch + + +logger = logging.getLogger(__name__) + + +def infer_language_pair(path): + """Infer language pair from filename: .-.(...).idx""" + src, dst = None, None + for filename in os.listdir(path): + parts = filename.split(".") + if len(parts) >= 3 and len(parts[1].split("-")) == 2: + return parts[1].split("-") + return src, dst + + +def collate_tokens( + values, + pad_idx, + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + pad_to_length=None, + pad_to_multiple=1, +): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + res = values[0].new(len(values), size).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if move_eos_to_beginning: + if eos_idx is None: + # if no eos_idx is specified, then use the last token in src + dst[0] = src[-1] + else: + dst[0] = eos_idx + dst[1:] = src[:-1] + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) + return res + + +def load_indexed_dataset( + path, dictionary=None, dataset_impl=None, combine=False, default="cached" +): + """A helper function for loading indexed datasets. + + Args: + path (str): path to indexed dataset (e.g., 'data-bin/train') + dictionary (~fairseq.data.Dictionary): data dictionary + dataset_impl (str, optional): which dataset implementation to use. If + not provided, it will be inferred automatically. For legacy indexed + data we use the 'cached' implementation by default. + combine (bool, optional): automatically load and combine multiple + datasets. For example, if *path* is 'data-bin/train', then we will + combine 'data-bin/train', 'data-bin/train1', ... and return a + single ConcatDataset instance. + """ + from fairseq.data.concat_dataset import ConcatDataset + import fairseq.data.indexed_dataset as indexed_dataset + + datasets = [] + for k in itertools.count(): + path_k = path + (str(k) if k > 0 else "") + path_k = indexed_dataset.get_indexed_dataset_to_local(path_k) + + dataset_impl_k = dataset_impl + if dataset_impl_k is None: + dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k) + dataset = indexed_dataset.make_dataset( + path_k, + impl=dataset_impl_k or default, + fix_lua_indexing=True, + dictionary=dictionary, + ) + if dataset is None: + break + logger.info("loaded {} examples from: {}".format(len(dataset), path_k)) + datasets.append(dataset) + if not combine: + break + if len(datasets) == 0: + return None + elif len(datasets) == 1: + return datasets[0] + else: + return ConcatDataset(datasets) + + +@contextlib.contextmanager +def numpy_seed(seed, *addl_seeds): + """Context manager which seeds the NumPy PRNG with the specified seed and + restores the state afterward""" + if seed is None: + yield + return + if len(addl_seeds) > 0: + seed = int(hash((seed, *addl_seeds)) % 1e6) + state = np.random.get_state() + np.random.seed(seed) + try: + yield + finally: + np.random.set_state(state) + + +def collect_filtered(function, iterable, filtered): + """ + Similar to :func:`filter` but collects filtered elements in ``filtered``. + + Args: + function (callable): function that returns ``False`` for elements that + should be filtered + iterable (iterable): iterable to filter + filtered (list): list to store filtered elements + """ + for el in iterable: + if function(el): + yield el + else: + filtered.append(el) + + +def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False): + def compare_leq(a, b): + return a <= b if not isinstance(a, tuple) else max(a) <= b + + def check_size(idx): + if isinstance(max_positions, float) or isinstance(max_positions, int): + return size_fn(idx) <= max_positions + elif isinstance(max_positions, dict): + idx_size = size_fn(idx) + assert isinstance(idx_size, dict) + intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) + return all( + all( + a is None or b is None or a <= b + for a, b in zip(idx_size[key], max_positions[key]) + ) + for key in intersect_keys + ) + else: + # Hacky as heck, for the specific case of multilingual training with RoundRobin. + if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple): + return all( + a is None or b is None or compare_leq(a, b) + for a, b in zip(size_fn(idx).values(), max_positions) + ) + # For MultiCorpusSampledDataset, will generalize it later + if not isinstance(size_fn(idx), Iterable): + return all(size_fn(idx) <= b for b in max_positions) + return all( + a is None or b is None or a <= b + for a, b in zip(size_fn(idx), max_positions) + ) + + ignored = [] + itr = collect_filtered(check_size, indices, ignored) + indices = np.fromiter(itr, dtype=np.int64, count=-1) + return indices, ignored + + +def filter_by_size(indices, dataset, max_positions, raise_exception=False): + """ + [deprecated] Filter indices based on their size. + Use `FairseqDataset::filter_indices_by_size` instead. + + Args: + indices (List[int]): ordered list of dataset indices + dataset (FairseqDataset): fairseq dataset instance + max_positions (tuple): filter elements larger than this size. + Comparisons are done component-wise. + raise_exception (bool, optional): if ``True``, raise an exception if + any elements are filtered (default: False). + """ + warnings.warn( + "data_utils.filter_by_size is deprecated. " + "Use `FairseqDataset::filter_indices_by_size` instead.", + stacklevel=2, + ) + if isinstance(max_positions, float) or isinstance(max_positions, int): + if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray): + ignored = indices[dataset.sizes[indices] > max_positions].tolist() + indices = indices[dataset.sizes[indices] <= max_positions] + elif ( + hasattr(dataset, "sizes") + and isinstance(dataset.sizes, list) + and len(dataset.sizes) == 1 + ): + ignored = indices[dataset.sizes[0][indices] > max_positions].tolist() + indices = indices[dataset.sizes[0][indices] <= max_positions] + else: + indices, ignored = _filter_by_size_dynamic( + indices, dataset.size, max_positions + ) + else: + indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions) + + if len(ignored) > 0 and raise_exception: + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) + if len(ignored) > 0: + logger.warning( + ( + "{} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) + return indices + + +def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if max_sizes is None: + return indices, [] + if type(max_sizes) in (int, float): + max_src_size, max_tgt_size = max_sizes, max_sizes + else: + max_src_size, max_tgt_size = max_sizes + if tgt_sizes is None: + ignored = indices[src_sizes[indices] > max_src_size] + else: + ignored = indices[ + (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size) + ] + if len(ignored) > 0: + if tgt_sizes is None: + indices = indices[src_sizes[indices] <= max_src_size] + else: + indices = indices[ + (src_sizes[indices] <= max_src_size) + & (tgt_sizes[indices] <= max_tgt_size) + ] + return indices, ignored.tolist() + + +def batch_by_size( + indices, + num_tokens_fn, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + fixed_shapes=None, +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be less than N or a multiple of N (default: 1). + fixed_shapes (List[Tuple[int, int]], optional): if given, batches will + only be created with the given shapes. *max_sentences* and + *required_batch_size_multiple* will be ignored (default: None). + """ + try: + from fairseq.data.data_utils_fast import ( + batch_by_size_fast, + batch_fixed_shapes_fast, + ) + except ImportError: + raise ImportError( + "Please build Cython components with: `pip install --editable .` " + "or `python setup.py build_ext --inplace`" + ) + + max_tokens = max_tokens if max_tokens is not None else -1 + max_sentences = max_sentences if max_sentences is not None else -1 + bsz_mult = required_batch_size_multiple + + if not isinstance(indices, np.ndarray): + indices = np.fromiter(indices, dtype=np.int64, count=-1) + + if fixed_shapes is None: + return batch_by_size_fast( + indices, + num_tokens_fn, + max_tokens, + max_sentences, + bsz_mult, + ) + else: + fixed_shapes = np.array(fixed_shapes, dtype=np.int64) + sort_order = np.lexsort( + [ + fixed_shapes[:, 1].argsort(), # length + fixed_shapes[:, 0].argsort(), # bsz + ] + ) + fixed_shapes_sorted = fixed_shapes[sort_order] + return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted) + + +def post_process(sentence: str, symbol: str): + if symbol == "sentencepiece": + sentence = sentence.replace(" ", "").replace("\u2581", " ").strip() + elif symbol == "wordpiece": + sentence = sentence.replace(" ", "").replace("_", " ").strip() + elif symbol == "letter": + sentence = sentence.replace(" ", "").replace("|", " ").strip() + elif symbol == "_EOW": + sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() + elif symbol is not None and symbol != "none": + sentence = (sentence + " ").replace(symbol, "").rstrip() + return sentence + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +def get_mem_usage(): + try: + import psutil + + mb = 1024 * 1024 + return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb" + except ImportError: + return "N/A" + + +def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor: + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor: + return ~lengths_to_padding_mask(lens) diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx new file mode 100644 index 0000000..38b4aa6 --- /dev/null +++ b/fairseq/data/data_utils_fast.pyx @@ -0,0 +1,123 @@ +# cython: language_level=3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +cimport cython +cimport numpy as np + +from libc.stdint cimport int32_t, int64_t + +ctypedef int64_t DTYPE_t + + +cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences): + if num_sentences == 0: + return 0 + if max_sentences > 0 and num_sentences == max_sentences: + return 1 + if max_tokens > 0 and num_tokens > max_tokens: + return 1 + return 0 + + +@cython.cdivision(True) +cpdef list batch_by_size_fast( + np.ndarray[DTYPE_t, ndim=1] indices, + num_tokens_fn, + int64_t max_tokens, + int64_t max_sentences, + int32_t bsz_mult, +): + cdef int64_t sample_len = 0 + cdef list sample_lens = [] + cdef list batch = [] + cdef list batches = [] + cdef int64_t mod_len + cdef int64_t i + cdef int64_t idx + cdef int64_t num_tokens + cdef DTYPE_t[:] indices_view = indices + + for i in range(len(indices_view)): + idx = indices_view[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + assert max_tokens <= 0 or sample_len <= max_tokens, ( + "sentence at index {} of size {} exceeds max_tokens " + "limit of {}!".format(idx, sample_len, max_tokens) + ) + num_tokens = (len(batch) + 1) * sample_len + + if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences): + mod_len = max( + bsz_mult * (len(batch) // bsz_mult), + len(batch) % bsz_mult, + ) + batches.append(batch[:mod_len]) + batch = batch[mod_len:] + sample_lens = sample_lens[mod_len:] + sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 + batch.append(idx) + if len(batch) > 0: + batches.append(batch) + return batches + + +cdef _find_valid_shape( + DTYPE_t[:, :] shapes_view, + int64_t num_sentences, + int64_t num_tokens, +): + """Return index of first valid shape of -1 if none is found.""" + for i in range(shapes_view.shape[0]): + if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: + return i + return -1 + + +@cython.cdivision(True) +cpdef list batch_fixed_shapes_fast( + np.ndarray[DTYPE_t, ndim=1] indices, + num_tokens_fn, + np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted, +): + cdef int64_t sample_len = 0 + cdef list sample_lens = [] + cdef list batch = [] + cdef list batches = [] + cdef int64_t mod_len + cdef int64_t i + cdef int64_t idx + cdef int64_t num_tokens + cdef DTYPE_t[:] indices_view = indices + cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted + + for i in range(len(indices_view)): + idx = indices_view[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + if shape_idx == -1: + batches.append(batch) + batch = [] + sample_lens = [] + sample_len = 0 + shapes_view = fixed_shapes_sorted + elif shape_idx > 0: + # small optimization for the next call to _find_valid_shape + shapes_view = shapes_view[shape_idx:] + + batch.append(idx) + + if len(batch) > 0: + batches.append(batch) + + return batches diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py new file mode 100644 index 0000000..a092bfb --- /dev/null +++ b/fairseq/data/dictionary.py @@ -0,0 +1,418 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from collections import Counter +from multiprocessing import Pool + +import torch +from fairseq import utils +from fairseq.binarizer import safe_readline +from fairseq.data import data_utils +from fairseq.file_io import PathManager +from fairseq.tokenizer import tokenize_line + + +class Dictionary(object): + """A mapping from symbols to consecutive integers""" + + def __init__( + self, + input_file=None, # begin keyword-only arguments + bos="", + pad="", + eos="", + unk="", + extra_special_symbols=None, + ): + + self.symbols = [] + self.count = [] + self.indices = {} + if input_file is not None and 'json' in input_file: + self.add_from_json(input_file) + else: + self.add_from_file(input_file) + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + self.bos_index = self.add_symbol(bos) + self.pad_index = self.add_symbol(pad) + self.eos_index = self.add_symbol(eos) + self.unk_index = self.add_symbol(unk) + if extra_special_symbols: + for s in extra_special_symbols: + self.add_symbol(s) + self.nspecial = len(self.symbols) + + def __eq__(self, other): + return self.indices == other.indices + + def __getitem__(self, idx): + if idx < len(self.symbols): + return self.symbols[idx] + return self.unk_word + + def __len__(self): + """Returns the number of symbols in the dictionary""" + return len(self.symbols) + + def __contains__(self, sym): + return sym in self.indices + + def index(self, sym): + """Returns the index of the specified symbol""" + assert isinstance(sym, str) + if sym in self.indices: + return self.indices[sym] + return self.unk_index + + def string( + self, + tensor, + bpe_symbol=None, + escape_unk=False, + extra_symbols_to_ignore=None, + unk_string=None, + ): + """Helper for converting a tensor of token indices to a string. + + Can optionally remove BPE symbols or escape words. + """ + if torch.is_tensor(tensor) and tensor.dim() == 2: + return "\n".join( + self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore) + for t in tensor + ) + + extra_symbols_to_ignore = set(extra_symbols_to_ignore or []) + extra_symbols_to_ignore.add(self.eos()) + + def token_string(i): + if i == self.unk(): + if unk_string is not None: + return unk_string + else: + return self.unk_string(escape_unk) + else: + return self[i] + + if hasattr(self, "bos_index"): + extra_symbols_to_ignore.add(self.bos()) + + sent = " ".join( + token_string(i) + for i in tensor + if utils.item(i) not in extra_symbols_to_ignore + ) + + return data_utils.post_process(sent, bpe_symbol) + + def unk_string(self, escape=False): + """Return unknown string, optionally escaped as: <>""" + if escape: + return "<{}>".format(self.unk_word) + else: + return self.unk_word + + def add_symbol(self, word, n=1, overwrite=False): + """Adds a word to the dictionary""" + if word in self.indices and not overwrite: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def update(self, new_dict): + """Updates counts from new dictionary.""" + for word in new_dict.symbols: + idx2 = new_dict.indices[word] + if word in self.indices: + idx = self.indices[word] + self.count[idx] = self.count[idx] + new_dict.count[idx2] + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(new_dict.count[idx2]) + + def finalize(self, threshold=-1, nwords=-1, padding_factor=8): + """Sort symbols by frequency in descending order, ignoring special ones. + + Args: + - threshold defines the minimum word count + - nwords defines the total number of words in the final dictionary, + including special symbols + - padding_factor can be used to pad the dictionary size to be a + multiple of 8, which is important on some hardware (e.g., Nvidia + Tensor Cores). + """ + if nwords <= 0: + nwords = len(self) + + new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial))) + new_symbols = self.symbols[: self.nspecial] + new_count = self.count[: self.nspecial] + + c = Counter( + dict( + sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :])) + ) + ) + for symbol, count in c.most_common(nwords - self.nspecial): + if count >= threshold: + new_indices[symbol] = len(new_symbols) + new_symbols.append(symbol) + new_count.append(count) + else: + break + + assert len(new_symbols) == len(new_indices) + + self.count = list(new_count) + self.symbols = list(new_symbols) + self.indices = new_indices + + self.pad_to_multiple_(padding_factor) + + def pad_to_multiple_(self, padding_factor): + """Pad Dictionary size to be a multiple of *padding_factor*.""" + if padding_factor > 1: + i = 0 + while len(self) % padding_factor != 0: + symbol = "madeupword{:04d}".format(i) + self.add_symbol(symbol, n=0) + i += 1 + + def bos(self): + """Helper to get index of beginning-of-sentence symbol""" + return self.bos_index + + def pad(self): + """Helper to get index of pad symbol""" + return self.pad_index + + def eos(self): + """Helper to get index of end-of-sentence symbol""" + return self.eos_index + + def unk(self): + """Helper to get index of unk symbol""" + return self.unk_index + + @classmethod + def load(cls, f): + """Loads the dictionary from a text file with the format: + + ``` + + + ... + ``` + """ + d = cls() + d.add_from_file(f) + return d + + @classmethod + def load_from_json(cls, f): + d = cls() + d.add_from_json(f) + return d + + def add_from_json(self, f): + import json + if isinstance(f, str): + try: + with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd: + self.add_from_json(fd) + except FileNotFoundError as fnfe: + raise fnfe + except UnicodeError: + raise Exception( + "Incorrect encoding detected in {}, please " + "rebuild the dataset".format(f) + ) + return + + vocab = json.load(f) + for k, v in vocab.items(): + self.add_symbol(k) + + + def add_from_file(self, f): + """ + Loads a pre-existing dictionary from a text file and adds its symbols + to this instance. + """ + if isinstance(f, str): + try: + with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd: + self.add_from_file(fd) + except FileNotFoundError as fnfe: + raise fnfe + except UnicodeError: + raise Exception( + "Incorrect encoding detected in {}, please " + "rebuild the dataset".format(f) + ) + return + + lines = f.readlines() + indices_start_line = self._load_meta(lines) + + for line in lines[indices_start_line:]: + try: + line, field = line.rstrip().rsplit(" ", 1) + if field == "#fairseq:overwrite": + overwrite = True + line, field = line.rsplit(" ", 1) + else: + overwrite = False + count = int(field) + word = line + if word in self and not overwrite: + raise RuntimeError( + "Duplicate word found when loading Dictionary: '{}'. " + "Duplicate words can overwrite earlier ones by adding the " + "#fairseq:overwrite flag at the end of the corresponding row " + "in the dictionary file. If using the Camembert model, please " + "download an updated copy of the model file.".format(word) + ) + self.add_symbol(word, n=count, overwrite=overwrite) + except ValueError: + raise ValueError( + "Incorrect dictionary format, expected ' [flags]'" + ) + + def _save(self, f, kv_iterator): + if isinstance(f, str): + PathManager.mkdirs(os.path.dirname(f)) + with PathManager.open(f, "w", encoding="utf-8") as fd: + return self.save(fd) + for k, v in kv_iterator: + print("{} {}".format(k, v), file=f) + + def _get_meta(self): + return [], [] + + def _load_meta(self, lines): + return 0 + + def save(self, f): + """Stores dictionary into a text file""" + ex_keys, ex_vals = self._get_meta() + self._save( + f, + zip( + ex_keys + self.symbols[self.nspecial :], + ex_vals + self.count[self.nspecial :], + ), + ) + + def dummy_sentence(self, length): + t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long() + t[-1] = self.eos() + return t + + def encode_line( + self, + line, + line_tokenizer=tokenize_line, + add_if_not_exist=True, + consumer=None, + append_eos=True, + reverse_order=False, + ): + words = line_tokenizer(line) + if reverse_order: + words = list(reversed(words)) + nwords = len(words) + ids = torch.IntTensor(nwords + 1 if append_eos else nwords) + + for i, word in enumerate(words): + if add_if_not_exist: + idx = self.add_symbol(word) + else: + idx = self.index(word) + if consumer is not None: + consumer(word, idx) + ids[i] = idx + if append_eos: + ids[nwords] = self.eos_index + return ids + + @staticmethod + def _add_file_to_dictionary_single_worker( + filename, tokenize, eos_word, worker_id=0, num_workers=1 + ): + counter = Counter() + with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f: + size = os.fstat(f.fileno()).st_size + chunk_size = size // num_workers + offset = worker_id * chunk_size + end = offset + chunk_size + f.seek(offset) + if offset > 0: + safe_readline(f) # drop first incomplete line + line = f.readline() + while line: + for word in tokenize(line): + counter.update([word]) + counter.update([eos_word]) + if f.tell() > end: + break + line = f.readline() + return counter + + @staticmethod + def add_file_to_dictionary(filename, dict, tokenize, num_workers): + def merge_result(counter): + for w, c in sorted(counter.items()): + dict.add_symbol(w, c) + + if num_workers > 1: + pool = Pool(processes=num_workers) + results = [] + for worker_id in range(num_workers): + results.append( + pool.apply_async( + Dictionary._add_file_to_dictionary_single_worker, + (filename, tokenize, dict.eos_word, worker_id, num_workers), + ) + ) + pool.close() + pool.join() + for r in results: + merge_result(r.get()) + else: + merge_result( + Dictionary._add_file_to_dictionary_single_worker( + filename, tokenize, dict.eos_word + ) + ) + + +class TruncatedDictionary(object): + def __init__(self, wrapped_dict, length): + self.__class__ = type( + wrapped_dict.__class__.__name__, + (self.__class__, wrapped_dict.__class__), + {}, + ) + self.__dict__ = wrapped_dict.__dict__ + self.wrapped_dict = wrapped_dict + self.length = min(len(self.wrapped_dict), length) + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i < self.length: + return self.wrapped_dict[i] + return self.wrapped_dict.unk() diff --git a/fairseq/data/encoders/__init__.py b/fairseq/data/encoders/__init__.py new file mode 100644 index 0000000..2e807d8 --- /dev/null +++ b/fairseq/data/encoders/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os + +from fairseq import registry + + +build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry( + "--tokenizer", + default=None, +) + + +build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry( + "--bpe", + default=None, +) + + +# automatically import any Python files in the encoders/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.data.encoders." + module) diff --git a/fairseq/data/encoders/byte_bpe.py b/fairseq/data/encoders/byte_bpe.py new file mode 100644 index 0000000..31e3a06 --- /dev/null +++ b/fairseq/data/encoders/byte_bpe.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.data.encoders.byte_utils import ( + SPACE, + SPACE_ESCAPE, + byte_encode, + smart_byte_decode, +) +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class ByteBpeConfig(FairseqDataclass): + sentencepiece_model_path: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) + + +@register_bpe("byte_bpe", dataclass=ByteBpeConfig) +class ByteBPE(object): + def __init__(self, cfg): + vocab = file_utils.cached_path(cfg.sentencepiece_model_path) + try: + import sentencepiece as spm + + self.sp = spm.SentencePieceProcessor() + self.sp.Load(vocab) + except ImportError: + raise ImportError( + "Please install sentencepiece with: pip install sentencepiece" + ) + + def encode(self, x: str) -> str: + byte_encoded = byte_encode(x) + return SPACE.join(self.sp.EncodeAsPieces(byte_encoded)) + + @staticmethod + def decode(x: str) -> str: + unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) + return smart_byte_decode(unescaped) diff --git a/fairseq/data/encoders/byte_utils.py b/fairseq/data/encoders/byte_utils.py new file mode 100644 index 0000000..a305c08 --- /dev/null +++ b/fairseq/data/encoders/byte_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + + +WHITESPACE_NORMALIZER = re.compile(r"\s+") +SPACE = chr(32) +SPACE_ESCAPE = chr(9601) +# excluding non-breaking space (160) here +PRINTABLE_LATIN = set( + list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1)) +) +BYTE_TO_BCHAR = { + b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256) +} +BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} + + +def byte_encode(x: str) -> str: + normalized = WHITESPACE_NORMALIZER.sub(SPACE, x) + return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")]) + + +def byte_decode(x: str) -> str: + try: + return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8") + except ValueError: + return "" + + +def smart_byte_decode(x: str) -> str: + output = byte_decode(x) + if output == "": + # DP the best recovery (max valid chars) if it's broken + n_bytes = len(x) + f = [0 for _ in range(n_bytes + 1)] + pt = [0 for _ in range(n_bytes + 1)] + for i in range(1, n_bytes + 1): + f[i], pt[i] = f[i - 1], i - 1 + for j in range(1, min(4, i) + 1): + if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0: + f[i], pt[i] = f[i - j] + 1, i - j + cur_pt = n_bytes + while cur_pt > 0: + if f[cur_pt] == f[pt[cur_pt]] + 1: + output = byte_decode(x[pt[cur_pt] : cur_pt]) + output + cur_pt = pt[cur_pt] + return output diff --git a/fairseq/data/encoders/bytes.py b/fairseq/data/encoders/bytes.py new file mode 100644 index 0000000..f88f8f6 --- /dev/null +++ b/fairseq/data/encoders/bytes.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from fairseq.data.encoders import register_bpe +from fairseq.data.encoders.byte_utils import ( + SPACE, + SPACE_ESCAPE, + byte_encode, + smart_byte_decode, +) + + +@register_bpe("bytes") +class Bytes(object): + def __init__(self, *unused): + pass + + @staticmethod + def add_args(parser): + pass + + @staticmethod + def encode(x: str) -> str: + encoded = byte_encode(x) + escaped = encoded.replace(SPACE, SPACE_ESCAPE) + return SPACE.join(list(escaped)) + + @staticmethod + def decode(x: str) -> str: + unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) + return smart_byte_decode(unescaped) diff --git a/fairseq/data/encoders/characters.py b/fairseq/data/encoders/characters.py new file mode 100644 index 0000000..494ea21 --- /dev/null +++ b/fairseq/data/encoders/characters.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from fairseq.data.encoders import register_bpe + + +SPACE = chr(32) +SPACE_ESCAPE = chr(9601) + + +@register_bpe("characters") +class Characters(object): + def __init__(self, *unused): + pass + + @staticmethod + def add_args(parser): + pass + + @staticmethod + def encode(x: str) -> str: + escaped = x.replace(SPACE, SPACE_ESCAPE) + return SPACE.join(list(escaped)) + + @staticmethod + def decode(x: str) -> str: + return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py new file mode 100644 index 0000000..f7c2103 --- /dev/null +++ b/fairseq/data/encoders/fastbpe.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class fastBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"}) + + +@register_bpe("fastbpe", dataclass=fastBPEConfig) +class fastBPE(object): + def __init__(self, cfg): + if cfg.bpe_codes is None: + raise ValueError("--bpe-codes is required for --bpe=fastbpe") + codes = file_utils.cached_path(cfg.bpe_codes) + try: + import fastBPE + + self.bpe = fastBPE.fastBPE(codes) + self.bpe_symbol = "@@ " + except ImportError: + raise ImportError("Please install fastBPE with: pip install fastBPE") + + def encode(self, x: str) -> str: + return self.bpe.apply([x])[0] + + def decode(self, x: str) -> str: + return (x + " ").replace(self.bpe_symbol, "").rstrip() diff --git a/fairseq/data/encoders/gpt2_bpe.py b/fairseq/data/encoders/gpt2_bpe.py new file mode 100644 index 0000000..e661426 --- /dev/null +++ b/fairseq/data/encoders/gpt2_bpe.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + +from .gpt2_bpe_utils import get_encoder + + +DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" +DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" + + +@dataclass +class GPT2BPEConfig(FairseqDataclass): + gpt2_encoder_json: str = field( + default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"} + ) + gpt2_vocab_bpe: str = field( + default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"} + ) + + +@register_bpe("gpt2", dataclass=GPT2BPEConfig) +class GPT2BPE(object): + def __init__(self, cfg): + encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json) + vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe) + self.bpe = get_encoder(encoder_json, vocab_bpe) + + def encode(self, x: str) -> str: + return " ".join(map(str, self.bpe.encode(x))) + + def decode(self, x: str) -> str: + return self.bpe.decode( + [int(tok) if tok not in {"", ""} else tok for tok in x.split()] + ) + + def is_beginning_of_word(self, x: str) -> bool: + return self.decode(x).startswith(" ") diff --git a/fairseq/data/encoders/gpt2_bpe_utils.py b/fairseq/data/encoders/gpt2_bpe_utils.py new file mode 100644 index 0000000..688d4e3 --- /dev/null +++ b/fairseq/data/encoders/gpt2_bpe_utils.py @@ -0,0 +1,140 @@ +""" +Byte pair encoding utilities from GPT-2. + +Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py +Original license: MIT +""" + +import json +from functools import lru_cache + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2 ** 8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + try: + import regex as re + + self.re = re + except ImportError: + raise ImportError("Please install regex with: pip install regex") + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = self.re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + for token in self.re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder.get(token, token) for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + "utf-8", errors=self.errors + ) + return text + + +def get_encoder(encoder_json_path, vocab_bpe_path): + with open(encoder_json_path, "r") as f: + encoder = json.load(f) + with open(vocab_bpe_path, "r", encoding="utf-8") as f: + bpe_data = f.read() + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] + return Encoder( + encoder=encoder, + bpe_merges=bpe_merges, + ) diff --git a/fairseq/data/encoders/hf_bert_bpe.py b/fairseq/data/encoders/hf_bert_bpe.py new file mode 100644 index 0000000..a41c059 --- /dev/null +++ b/fairseq/data/encoders/hf_bert_bpe.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Optional + +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class BertBPEConfig(FairseqDataclass): + bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"}) + bpe_vocab_file: Optional[str] = field( + default=None, metadata={"help": "bpe vocab file"} + ) + + +@register_bpe("bert", dataclass=BertBPEConfig) +class BertBPE(object): + def __init__(self, cfg): + try: + from transformers import BertTokenizer + except ImportError: + raise ImportError( + "Please install transformers with: pip install transformers" + ) + + if cfg.bpe_vocab_file: + self.bert_tokenizer = BertTokenizer( + cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased + ) + else: + vocab_file_name = ( + "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased" + ) + self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) + + def encode(self, x: str) -> str: + return " ".join(self.bert_tokenizer.tokenize(x)) + + def decode(self, x: str) -> str: + return self.bert_tokenizer.clean_up_tokenization( + self.bert_tokenizer.convert_tokens_to_string(x.split(" ")) + ) + + def is_beginning_of_word(self, x: str) -> bool: + return not x.startswith("##") diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py new file mode 100644 index 0000000..92d2c39 --- /dev/null +++ b/fairseq/data/encoders/hf_byte_bpe.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class HuggingFaceByteLevelBPEConfig(FairseqDataclass): + bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"}) + bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"}) + bpe_add_prefix_space: bool = field( + default=False, metadata={"help": "add prefix space before encoding"} + ) + + +@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig) +class HuggingFaceByteLevelBPE(object): + def __init__(self, cfg): + try: + from tokenizers import ByteLevelBPETokenizer + except ImportError: + raise ImportError( + "Please install huggingface/tokenizers with: " "pip install tokenizers" + ) + + self.bpe = ByteLevelBPETokenizer( + cfg.bpe_vocab, + cfg.bpe_merges, + add_prefix_space=cfg.bpe_add_prefix_space, + ) + + def encode(self, x: str) -> str: + return " ".join(map(str, self.bpe.encode(x).ids)) + + def decode(self, x: str) -> str: + return self.bpe.decode( + [int(tok) if tok not in {"", ""} else tok for tok in x.split()] + ) + + def is_beginning_of_word(self, x: str) -> bool: + return self.decode(x).startswith(" ") diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py new file mode 100644 index 0000000..fa004dd --- /dev/null +++ b/fairseq/data/encoders/moses_tokenizer.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class MosesTokenizerConfig(FairseqDataclass): + source_lang: str = field(default="en", metadata={"help": "source language"}) + target_lang: str = field(default="en", metadata={"help": "target language"}) + moses_no_dash_splits: bool = field( + default=False, metadata={"help": "don't apply dash split rules"} + ) + moses_no_escape: bool = field( + default=False, + metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."}, + ) + + +@register_tokenizer("moses", dataclass=MosesTokenizerConfig) +class MosesTokenizer(object): + def __init__(self, cfg): + self.cfg = cfg + + try: + from sacremoses import MosesTokenizer, MosesDetokenizer + + self.tok = MosesTokenizer(cfg.source_lang) + self.detok = MosesDetokenizer(cfg.target_lang) + except ImportError: + raise ImportError( + "Please install Moses tokenizer with: pip install sacremoses" + ) + + def encode(self, x: str) -> str: + return self.tok.tokenize( + x, + aggressive_dash_splits=(not self.cfg.moses_no_dash_splits), + return_str=True, + escape=(not self.cfg.moses_no_escape), + ) + + def decode(self, x: str) -> str: + return self.detok.detokenize(x.split()) diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py new file mode 100644 index 0000000..ee16471 --- /dev/null +++ b/fairseq/data/encoders/nltk_tokenizer.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data.encoders import register_tokenizer + + +@register_tokenizer("nltk") +class NLTKTokenizer(object): + def __init__(self, *unused): + try: + from nltk.tokenize import word_tokenize + + self.word_tokenize = word_tokenize + except ImportError: + raise ImportError("Please install nltk with: pip install nltk") + + def encode(self, x: str) -> str: + return " ".join(self.word_tokenize(x)) + + def decode(self, x: str) -> str: + return x diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py new file mode 100644 index 0000000..a76d46a --- /dev/null +++ b/fairseq/data/encoders/sentencepiece_bpe.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class SentencepieceConfig(FairseqDataclass): + sentencepiece_model: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) + + +@register_bpe("sentencepiece", dataclass=SentencepieceConfig) +class SentencepieceBPE(object): + def __init__(self, cfg): + sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model) + try: + import sentencepiece as spm + + self.sp = spm.SentencePieceProcessor() + self.sp.Load(sentencepiece_model) + except ImportError: + raise ImportError( + "Please install sentencepiece with: pip install sentencepiece" + ) + + def encode(self, x: str) -> str: + return " ".join(self.sp.EncodeAsPieces(x)) + + def decode(self, x: str) -> str: + return x.replace(" ", "").replace("\u2581", " ").strip() + + def is_beginning_of_word(self, x: str) -> bool: + if x in ["", "", "", ""]: + # special elements are always considered beginnings + # HACK: this logic is already present in fairseq/tasks/masked_lm.py + # but these special tokens are also contained in the sentencepiece + # vocabulary which causes duplicate special tokens. This hack makes + # sure that they are all taken into account. + return True + return x.startswith("\u2581") diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py new file mode 100644 index 0000000..7c7f644 --- /dev/null +++ b/fairseq/data/encoders/space_tokenizer.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + +from fairseq.data.encoders import register_tokenizer + + +@register_tokenizer("space") +class SpaceTokenizer(object): + def __init__(self, *unused): + self.space_tok = re.compile(r"\s+") + + def encode(self, x: str) -> str: + return self.space_tok.sub(" ", x) + + def decode(self, x: str) -> str: + return x diff --git a/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq/data/encoders/subword_nmt_bpe.py new file mode 100644 index 0000000..5d724d2 --- /dev/null +++ b/fairseq/data/encoders/subword_nmt_bpe.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class SubwordNMTBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"}) + bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"}) + + +@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig) +class SubwordNMTBPE(object): + def __init__(self, cfg): + if cfg.bpe_codes is None: + raise ValueError("--bpe-codes is required for --bpe=subword_nmt") + codes = file_utils.cached_path(cfg.bpe_codes) + try: + from subword_nmt import apply_bpe + + bpe_parser = apply_bpe.create_parser() + bpe_args = bpe_parser.parse_args( + [ + "--codes", + codes, + "--separator", + cfg.bpe_separator, + ] + ) + self.bpe = apply_bpe.BPE( + bpe_args.codes, + bpe_args.merges, + bpe_args.separator, + None, + bpe_args.glossaries, + ) + self.bpe_symbol = bpe_args.separator + " " + except ImportError: + raise ImportError( + "Please install subword_nmt with: pip install subword-nmt" + ) + + def encode(self, x: str) -> str: + return self.bpe.process_line(x) + + def decode(self, x: str) -> str: + return (x + " ").replace(self.bpe_symbol, "").rstrip() diff --git a/fairseq/data/encoders/utils.py b/fairseq/data/encoders/utils.py new file mode 100644 index 0000000..d93eb53 --- /dev/null +++ b/fairseq/data/encoders/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.data import encoders + + +def get_whole_word_mask(args, dictionary): + bpe = encoders.build_bpe(args) + if bpe is not None: + + def is_beginning_of_word(i): + if i < dictionary.nspecial: + # special elements are always considered beginnings + return True + tok = dictionary[i] + if tok.startswith("madeupword"): + return True + try: + return bpe.is_beginning_of_word(tok) + except ValueError: + return True + + mask_whole_words = torch.ByteTensor( + list(map(is_beginning_of_word, range(len(dictionary)))) + ) + return mask_whole_words + return None diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py new file mode 100644 index 0000000..ed08c1b --- /dev/null +++ b/fairseq/data/fairseq_dataset.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch.utils.data +from fairseq.data import data_utils + + +class EpochListening: + """Mixin for receiving updates whenever the epoch increments.""" + + @property + def can_reuse_epoch_itr_across_epochs(self): + """ + Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for + this dataset across epochs. + + This needs to return ``False`` if the sample sizes can change across + epochs, in which case we may need to regenerate batches at each epoch. + If your dataset relies in ``set_epoch`` then you should consider setting + this to ``False``. + """ + return True + + def set_epoch(self, epoch): + """Will receive the updated epoch number at the beginning of the epoch.""" + pass + + +class FairseqDataset(torch.utils.data.Dataset, EpochListening): + """A dataset that provides helpers for batching.""" + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + raise NotImplementedError + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + raise NotImplementedError + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + return np.arange(len(self), dtype=np.int64) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return False + + def attr(self, attr: str, index: int): + return getattr(self, attr, None) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + raise NotImplementedError + + def get_batch_shapes(self): + """ + Return a list of valid batch shapes, for example:: + + [(8, 512), (16, 256), (32, 128)] + + The first dimension of each tuple is the batch size and can be ``None`` + to automatically infer the max batch size based on ``--max-tokens``. + The second dimension of each tuple is the max supported length as given + by :func:`fairseq.data.FairseqDataset.num_tokens`. + + This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size` + to restrict batch shapes. This is useful on TPUs to avoid too many + dynamic shapes (and recompilations). + """ + return None + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + """ + Given an ordered set of indices, return batches according to + *max_tokens*, *max_sentences* and *required_batch_size_multiple*. + """ + from fairseq.data import data_utils + + fixed_shapes = self.get_batch_shapes() + if fixed_shapes is not None: + + def adjust_bsz(bsz, num_tokens): + if bsz is None: + assert max_tokens is not None, "Must specify --max-tokens" + bsz = max_tokens // num_tokens + if max_sentences is not None: + bsz = min(bsz, max_sentences) + elif ( + bsz >= required_batch_size_multiple + and bsz % required_batch_size_multiple != 0 + ): + bsz -= bsz % required_batch_size_multiple + return bsz + + fixed_shapes = np.array( + [ + [adjust_bsz(bsz, num_tokens), num_tokens] + for (bsz, num_tokens) in fixed_shapes + ] + ) + + return data_utils.batch_by_size( + indices, + num_tokens_fn=self.num_tokens, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + fixed_shapes=fixed_shapes, + ) + + def filter_indices_by_size(self, indices, max_sizes): + """ + Filter a list of sample indices. Remove those that are longer than + specified in *max_sizes*. + + WARNING: don't update, override method in child classes + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if isinstance(max_sizes, float) or isinstance(max_sizes, int): + if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray): + ignored = indices[self.sizes[indices] > max_sizes].tolist() + indices = indices[self.sizes[indices] <= max_sizes] + elif ( + hasattr(self, "sizes") + and isinstance(self.sizes, list) + and len(self.sizes) == 1 + ): + ignored = indices[self.sizes[0][indices] > max_sizes].tolist() + indices = indices[self.sizes[0][indices] <= max_sizes] + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored + + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return True + + +class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): + """ + For datasets that need to be read sequentially, usually because the data is + being streamed or otherwise can't be manipulated on a single machine. + """ + + def __iter__(self): + raise NotImplementedError diff --git a/fairseq/data/fasta_dataset.py b/fairseq/data/fasta_dataset.py new file mode 100644 index 0000000..0070119 --- /dev/null +++ b/fairseq/data/fasta_dataset.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import subprocess +import threading +from pathlib import Path + +import numpy as np +import torch + + +def fasta_file_path(prefix_path): + return prefix_path + ".fasta" + + +class FastaDataset(torch.utils.data.Dataset): + """ + For loading protein sequence datasets in the common FASTA data format + """ + + def __init__(self, path: str, cache_indices=False): + self.fn = fasta_file_path(path) + self.threadlocal = threading.local() + self.cache = Path(f"{path}.fasta.idx.npy") + if cache_indices: + if self.cache.exists(): + self.offsets, self.sizes = np.load(self.cache) + else: + self.offsets, self.sizes = self._build_index(path) + np.save(self.cache, np.stack([self.offsets, self.sizes])) + else: + self.offsets, self.sizes = self._build_index(path) + + def _get_file(self): + if not hasattr(self.threadlocal, "f"): + self.threadlocal.f = open(self.fn, "r") + return self.threadlocal.f + + def __getitem__(self, idx): + f = self._get_file() + f.seek(self.offsets[idx]) + desc = f.readline().strip() + line = f.readline() + seq = "" + while line != "" and line[0] != ">": + seq += line.strip() + line = f.readline() + return desc, seq + + def __len__(self): + return self.offsets.size + + def _build_index(self, path: str): + # Use grep and awk to get 100M/s on local SSD. + # Should process your enormous 100G fasta in ~10 min single core... + path = fasta_file_path(path) + bytes_offsets = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| grep --byte-offset '^>' -o | cut -d: -f1", + shell=True, + ) + fasta_lengths = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'", + shell=True, + ) + bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ") + sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ") + return bytes_np, sizes_np + + def __setstate__(self, state): + self.__dict__ = state + self.threadlocal = threading.local() + + def __getstate__(self): + d = {} + for i, v in self.__dict__.items(): + if i != "threadlocal": + d[i] = v + return d + + def __del__(self): + if hasattr(self.threadlocal, "f"): + self.threadlocal.f.close() + del self.threadlocal.f + + @staticmethod + def exists(path): + return os.path.exists(fasta_file_path(path)) + + +class EncodedFastaDataset(FastaDataset): + """ + The FastaDataset returns raw sequences - this allows us to return + indices with a dictionary instead. + """ + + def __init__(self, path, dictionary): + super().__init__(path, cache_indices=True) + self.dictionary = dictionary + + def __getitem__(self, idx): + desc, seq = super().__getitem__(idx) + return self.dictionary.encode_line(seq, line_tokenizer=list).long() diff --git a/fairseq/data/id_dataset.py b/fairseq/data/id_dataset.py new file mode 100644 index 0000000..3e4d796 --- /dev/null +++ b/fairseq/data/id_dataset.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class IdDataset(FairseqDataset): + def __getitem__(self, index): + return index + + def __len__(self): + return 0 + + def collater(self, samples): + return torch.tensor(samples) diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py new file mode 100644 index 0000000..827754d --- /dev/null +++ b/fairseq/data/indexed_dataset.py @@ -0,0 +1,561 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import shutil +import struct +from functools import lru_cache + +import numpy as np +import torch +from fairseq.dataclass.constants import DATASET_IMPL_CHOICES +from fairseq.data.fasta_dataset import FastaDataset +from fairseq.file_io import PathManager + +from . import FairseqDataset + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def get_available_dataset_impl(): + return list(map(str, DATASET_IMPL_CHOICES)) + + +def infer_dataset_impl(path): + if IndexedRawTextDataset.exists(path): + return "raw" + elif IndexedDataset.exists(path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return "cached" + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return "mmap" + else: + return None + elif FastaDataset.exists(path): + return "fasta" + else: + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) + elif impl == "fasta": + raise NotImplementedError + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None): + if impl == "raw" and IndexedRawTextDataset.exists(path): + assert dictionary is not None + return IndexedRawTextDataset(path, dictionary) + elif impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing) + elif impl == "cached" and IndexedDataset.exists(path): + return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path) + elif impl == "fasta" and FastaDataset.exists(path): + from fairseq.data.fasta_dataset import EncodedFastaDataset + + return EncodedFastaDataset(path, dictionary) + return None + + +def dataset_exists(path, impl): + if impl == "raw": + return IndexedRawTextDataset.exists(path) + elif impl == "mmap": + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float, + 7: np.double, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +class IndexedDataset(FairseqDataset): + """Loader for TorchNet IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path, fix_lua_indexing=False): + super().__init__() + self.path = path + self.fix_lua_indexing = fix_lua_indexing + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." + ) + version = f.read(8) + assert struct.unpack("= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + @lru_cache(maxsize=8) + def __getitem__(self, i): + if not self.data_file: + self.read_data(self.path) + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + item = torch.from_numpy(a).long() + if self.fix_lua_indexing: + item -= 1 # subtract 1 for 0-based indexing + return item + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return PathManager.exists(index_file_path(path)) and PathManager.exists( + data_file_path(path) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path, fix_lua_indexing=False): + super().__init__(path, fix_lua_indexing=fix_lua_indexing) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + @lru_cache(maxsize=8) + def __getitem__(self, i): + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + item = torch.from_numpy(a).long() + if self.fix_lua_indexing: + item -= 1 # subtract 1 for 0-based indexing + return item + + +class IndexedRawTextDataset(FairseqDataset): + """Takes a text file as input and binarizes it in memory at instantiation. + Original lines are also kept in memory""" + + def __init__(self, path, dictionary, append_eos=True, reverse_order=False): + self.tokens_list = [] + self.lines = [] + self.sizes = [] + self.append_eos = append_eos + self.reverse_order = reverse_order + self.read_data(path, dictionary) + self.size = len(self.tokens_list) + + def read_data(self, path, dictionary): + with open(path, "r", encoding="utf-8") as f: + for line in f: + self.lines.append(line.strip("\n")) + tokens = dictionary.encode_line( + line, + add_if_not_exist=False, + append_eos=self.append_eos, + reverse_order=self.reverse_order, + ).long() + self.tokens_list.append(tokens) + self.sizes.append(len(tokens)) + self.sizes = np.array(self.sizes) + + def check_index(self, i): + if i < 0 or i >= self.size: + raise IndexError("index out of range") + + @lru_cache(maxsize=8) + def __getitem__(self, i): + self.check_index(i) + return self.tokens_list[i] + + def get_original_text(self, i): + self.check_index(i) + return self.lines[i] + + def __del__(self): + pass + + def __len__(self): + return self.size + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return PathManager.exists(path) + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float: 4, + np.double: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + + def add_item(self, tensor): + # +1 for Lua compatibility + bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack("= self.total: + raise RuntimeError( + "Mismatch between actual and expected iterable length. " + "This may be caused by resuming training from a checkpoint using " + "a different number of GPUs, in which case you can try the " + "--reset-dataloader option. Alternatively you may have a train or " + "validation set that is smaller than the number of GPUs. If none " + "of these apply, please report this to the fairseq developers." + ) + self.n += 1 + yield x + + def __next__(self): + return next(self.itr) + + def has_next(self): + """Whether the iterator has been exhausted.""" + return self.n < len(self) + + def skip(self, num_to_skip): + """Fast-forward the iterator by skipping *num_to_skip* elements.""" + next(itertools.islice(self.itr, num_to_skip, num_to_skip), None) + return self + + def take(self, n): + """ + Truncates the iterator to n elements at most. + """ + self.total = min(self.total, n) + + # Propagate this change to the underlying iterator + # Only take after what we have already consumed (i.e. after restarting + # from checkpoint mid epoch, we have to subtract self.n which is the + # starting point) + # + # This to maintain the invariant self.total = self.n + len(iterable), + # before calling __next__ or __iter__ + propagated_take = max(n - self.n, 0) + if hasattr(self.iterable, "take"): + self.iterable.take(propagated_take) + else: + self.iterable = itertools.islice(self.iterable, propagated_take) + + +class EpochBatchIterating(object): + def __len__(self) -> int: + raise NotImplementedError + + @property + def next_epoch_idx(self): + raise NotImplementedError + + def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + """Return a new iterator over the dataset. + + Args: + shuffle (bool, optional): shuffle batches before returning the + iterator (default: True). + fix_batches_to_gpus: ensure that batches are always + allocated to the same shards across epochs. Requires + that :attr:`dataset` supports prefetching (default: False). + """ + raise NotImplementedError + + def end_of_epoch(self) -> bool: + """Returns whether the most recent epoch iterator has been exhausted""" + raise NotImplementedError + + @property + def iterations_in_epoch(self) -> int: + """The number of consumed batches in the current epoch.""" + raise NotImplementedError + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + raise NotImplementedError + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + raise NotImplementedError + + +class StreamingEpochBatchIterator(EpochBatchIterating): + def __init__( + self, + dataset, + epoch=1, + num_shards=1, + shard_id=0, + ): + assert isinstance(dataset, torch.utils.data.IterableDataset) + self.dataset = dataset + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self._current_epoch_iterator = None + self.num_shards = num_shards + self.shard_id = shard_id + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + if self._current_epoch_iterator is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + self.epoch = self.next_epoch_idx + self.dataset.set_epoch(self.epoch) + self._current_epoch_iterator = CountingIterator( + iterable=ShardedIterator( + iterable=self.dataset, + num_shards=self.num_shards, + shard_id=self.shard_id, + ), + ) + return self._current_epoch_iterator + + def end_of_epoch(self) -> bool: + return not self._current_epoch_iterator.has_next() + + @property + def iterations_in_epoch(self) -> int: + if self._current_epoch_iterator is not None: + return self._current_epoch_iterator.n + return 0 + + def state_dict(self): + return { + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + + +class EpochBatchIterator(EpochBatchIterating): + """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`. + + Compared to :class:`torch.utils.data.DataLoader`, this iterator: + + - can be reused across multiple epochs with the :func:`next_epoch_itr` + method (optionally shuffled between epochs) + - can be serialized/deserialized with the :func:`state_dict` and + :func:`load_state_dict` methods + - supports sharding with the *num_shards* and *shard_id* arguments + + Args: + dataset (~torch.utils.data.Dataset): dataset from which to load the data + collate_fn (callable): merges a list of samples to form a mini-batch + batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of + indices, or a callable to create such an iterator (~torch.utils.data.Sampler). + A callable batch_sampler will be called for each epoch to enable per epoch dynamic + batch iterators defined by this callable batch_sampler. + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + buffer_size (int, optional): the number of batches to keep ready in the + queue. Helps speeding up dataloading. When buffer_size is zero, the + default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: ``0``) + """ + + def __init__( + self, + dataset, + collate_fn, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + buffer_size=0, + timeout=0, + ): + assert isinstance(dataset, torch.utils.data.Dataset) + self.dataset = dataset + self.collate_fn = collate_fn + self.batch_sampler = batch_sampler + self._frozen_batches = ( + tuple(batch_sampler) if not callable(batch_sampler) else None + ) + self.seed = seed + self.num_shards = num_shards + self.shard_id = shard_id + self.num_workers = num_workers + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment. + self.buffer_size = min(buffer_size, 20) + self.timeout = timeout + + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self.shuffle = True + self._cur_epoch_itr = None + self._next_epoch_itr = None + self._supports_prefetch = getattr(dataset, "supports_prefetch", False) + + @property + def frozen_batches(self): + if self._frozen_batches is None: + self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) + return self._frozen_batches + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset." + ) + + if self.dataset.supports_fetch_outside_dataloader: + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) + else: + return "DUMMY" + + def __len__(self): + return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) + + @property + def n(self): + return self.iterations_in_epoch + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + if self._next_epoch_itr is not None: + return self.epoch + elif self._cur_epoch_itr is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False): + """Return a new iterator over the dataset. + + Args: + shuffle (bool, optional): shuffle batches before returning the + iterator (default: True). + fix_batches_to_gpus: ensure that batches are always + allocated to the same shards across epochs. Requires + that :attr:`dataset` supports prefetching (default: False). + """ + self.epoch = self.next_epoch_idx + self.dataset.set_epoch(self.epoch) + if self._next_epoch_itr is not None: + self._cur_epoch_itr = self._next_epoch_itr + self._next_epoch_itr = None + else: + if callable(self.batch_sampler): + # reset _frozen_batches to refresh the next epoch + self._frozen_batches = None + self._cur_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle, + fix_batches_to_gpus=fix_batches_to_gpus, + ) + self.shuffle = shuffle + return self._cur_epoch_itr + + def end_of_epoch(self) -> bool: + """Returns whether the most recent epoch iterator has been exhausted""" + return not self._cur_epoch_itr.has_next() + + @property + def iterations_in_epoch(self): + """The number of consumed batches in the current epoch.""" + if self._cur_epoch_itr is not None: + return self._cur_epoch_itr.n + elif self._next_epoch_itr is not None: + return self._next_epoch_itr.n + return 0 + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + if self.end_of_epoch(): + epoch = self.epoch + 1 + iter_in_epoch = 0 + else: + epoch = self.epoch + iter_in_epoch = self.iterations_in_epoch + return { + "version": 2, + "epoch": epoch, + "iterations_in_epoch": iter_in_epoch, + "shuffle": self.shuffle, + } + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + self.epoch = state_dict["epoch"] + itr_pos = state_dict.get("iterations_in_epoch", 0) + version = state_dict.get("version", 1) + if itr_pos > 0: + # fast-forward epoch iterator + self._next_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle=state_dict.get("shuffle", True), + offset=itr_pos, + ) + if self._next_epoch_itr is None: + if version == 1: + # legacy behavior: we finished the epoch, increment epoch counter + self.epoch += 1 + else: + raise RuntimeError( + "Cannot resume training due to dataloader mismatch, please " + "report this to the fairseq developers. You can relaunch " + "training with `--reset-dataloader` and it should work." + ) + else: + self._next_epoch_itr = None + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + np.random.shuffle(batches) + return batches + + if self._supports_prefetch: + batches = self.frozen_batches + + if shuffle and not fix_batches_to_gpus: + batches = shuffle_batches(list(batches), self.seed + epoch) + + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + self.dataset.prefetch([i for s in batches for i in s]) + + if shuffle and fix_batches_to_gpus: + batches = shuffle_batches(batches, self.seed + epoch + self.shard_id) + else: + if shuffle: + batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch) + else: + batches = self.frozen_batches + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + + if offset > 0 and offset >= len(batches): + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # Create data loader + itr = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=batches[offset:], + num_workers=self.num_workers, + timeout=self.timeout, + ) + + # Wrap with a BufferedIterator if needed + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # Wrap with CoutingIterator + itr = CountingIterator(itr, start=offset) + return itr + + +class GroupedIterator(CountingIterator): + """Wrapper around an iterable that returns groups (chunks) of items. + + Args: + iterable (iterable): iterable to wrap + chunk_size (int): size of each chunk + + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__(self, iterable, chunk_size): + itr = _chunk_iterator(iterable, chunk_size) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))), + total=int(math.ceil(len(iterable) / float(chunk_size))), + ) + self.chunk_size = chunk_size + + +def _chunk_iterator(itr, chunk_size): + chunk = [] + for x in itr: + chunk.append(x) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + if len(chunk) > 0: + yield chunk + + +class ShardedIterator(CountingIterator): + """A sharded wrapper around an iterable, padded to length. + + Args: + iterable (iterable): iterable to wrap + num_shards (int): number of shards to split the iterable into + shard_id (int): which shard to iterator over + fill_value (Any, optional): padding value when the iterable doesn't + evenly divide *num_shards* (default: None). + + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__(self, iterable, num_shards, shard_id, fill_value=None): + if shard_id < 0 or shard_id >= num_shards: + raise ValueError("shard_id must be between 0 and num_shards") + sharded_len = int(math.ceil(len(iterable) / float(num_shards))) + itr = map( + operator.itemgetter(1), + itertools.zip_longest( + range(sharded_len), + itertools.islice(iterable, shard_id, len(iterable), num_shards), + fillvalue=fill_value, + ), + ) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))), + total=sharded_len, + ) + + +class BackgroundConsumer(Thread): + def __init__(self, queue, source, max_len): + Thread.__init__(self) + + self._queue = queue + self._source = source + self._max_len = max_len + self.count = 0 + + def run(self): + try: + for item in self._source: + self._queue.put(item) + + # Stop if we reached the maximum length + self.count += 1 + if self._max_len is not None and self.count >= self._max_len: + break + + # Signal the consumer we are done. + self._queue.put(_sentinel) + except Exception as e: + self._queue.put(e) + + +class BufferedIterator(object): + def __init__(self, size, iterable): + self._queue = queue.Queue(size) + self._iterable = iterable + self._consumer = None + + self.start_time = time.time() + self.warning_time = None + + self.total = len(iterable) + + def _create_consumer(self): + self._consumer = BackgroundConsumer( + self._queue, + self._iterable, + self.total, + ) + self._consumer.daemon = True + self._consumer.start() + + def __iter__(self): + return self + + def __len__(self): + return self.total + + def take(self, n): + self.total = min(self.total, n) + + # Propagate this change to the underlying iterator + if hasattr(self._iterable, "take"): + self._iterable.take(n) + + def __next__(self): + # Create consumer if not created yet + if self._consumer is None: + self._create_consumer() + + # Notify the user if there is a data loading bottleneck + if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): + if time.time() - self.start_time > 5 * 60: + if ( + self.warning_time is None + or time.time() - self.warning_time > 15 * 60 + ): + logger.debug( + "Data loading buffer is empty or nearly empty. This may " + "indicate a data loading bottleneck, and increasing the " + "number of workers (--num-workers) may help." + ) + self.warning_time = time.time() + + # Get next example + item = self._queue.get(True) + if isinstance(item, Exception): + raise item + if item is _sentinel: + raise StopIteration() + return item diff --git a/fairseq/data/legacy/__init__.py b/fairseq/data/legacy/__init__.py new file mode 100644 index 0000000..9bd5c72 --- /dev/null +++ b/fairseq/data/legacy/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .block_pair_dataset import BlockPairDataset +from .masked_lm_dataset import MaskedLMDataset +from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary + + +__all__ = [ + "BertDictionary", + "BlockPairDataset", + "MaskedLMDataset", + "MaskedLMDictionary", +] diff --git a/fairseq/data/legacy/block_pair_dataset.py b/fairseq/data/legacy/block_pair_dataset.py new file mode 100644 index 0000000..ba069b4 --- /dev/null +++ b/fairseq/data/legacy/block_pair_dataset.py @@ -0,0 +1,311 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import numpy as np +import torch +from fairseq.data import FairseqDataset + + +class BlockPairDataset(FairseqDataset): + """Break a Dataset of tokens into sentence pair blocks for next sentence + prediction as well as masked language model. + + High-level logics are: + 1. break input tensor to tensor blocks + 2. pair the blocks with 50% next sentence and 50% random sentence + 3. return paired blocks as well as related segment labels + + Args: + dataset (~torch.utils.data.Dataset): dataset to break into blocks + sizes: array of sentence lengths + dictionary: dictionary for the task + block_size: maximum block size + break_mode: mode for breaking copurs into block pairs. currently we support + 2 modes + doc: respect document boundaries and each part of the pair should belong to on document + none: don't respect any boundary and cut tokens evenly + short_seq_prob: probability for generating shorter block pairs + doc_break_size: Size for empty line separating documents. Typically 1 if + the sentences have eos, 0 otherwise. + """ + + def __init__( + self, + dataset, + dictionary, + sizes, + block_size, + break_mode="doc", + short_seq_prob=0.1, + doc_break_size=1, + ): + super().__init__() + self.dataset = dataset + self.pad = dictionary.pad() + self.eos = dictionary.eos() + self.cls = dictionary.cls() + self.mask = dictionary.mask() + self.sep = dictionary.sep() + self.break_mode = break_mode + self.dictionary = dictionary + self.short_seq_prob = short_seq_prob + self.block_indices = [] + + assert len(dataset) == len(sizes) + + if break_mode == "doc": + cur_doc = [] + for sent_id, sz in enumerate(sizes): + assert doc_break_size == 0 or sz != 0, ( + "when doc_break_size is non-zero, we expect documents to be" + "separated by a blank line with a single eos." + ) + # empty line as document separator + if sz == doc_break_size: + if len(cur_doc) == 0: + continue + self.block_indices.append(cur_doc) + cur_doc = [] + else: + cur_doc.append(sent_id) + max_num_tokens = block_size - 3 # Account for [CLS], [SEP], [SEP] + self.sent_pairs = [] + self.sizes = [] + for doc_id, doc in enumerate(self.block_indices): + self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes) + elif break_mode is None or break_mode == "none": + # each block should have half of the block size since we are constructing block pair + sent_length = (block_size - 3) // 2 + total_len = sum(dataset.sizes) + length = math.ceil(total_len / sent_length) + + def block_at(i): + start = i * sent_length + end = min(start + sent_length, total_len) + return (start, end) + + sent_indices = np.array([block_at(i) for i in range(length)]) + sent_sizes = np.array([e - s for s, e in sent_indices]) + dataset_index = self._sent_to_dataset_index(sent_sizes) + + # pair sentences + self._pair_sentences(dataset_index) + else: + raise ValueError("Invalid break_mode: " + break_mode) + + def _pair_sentences(self, dataset_index): + """ + Give a list of evenly cut blocks/sentences, pair these sentences with 50% + consecutive sentences and 50% random sentences. + This is used for none break mode + """ + # pair sentences + for sent_id, sent in enumerate(dataset_index): + next_sent_label = ( + 1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0 + ) + if next_sent_label: + next_sent = dataset_index[sent_id + 1] + else: + next_sent = dataset_index[ + self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1]) + ] + self.sent_pairs.append((sent, next_sent, next_sent_label)) + + # The current blocks don't include the special tokens but the + # sizes already account for this + self.sizes.append(3 + sent[3] + next_sent[3]) + + def _sent_to_dataset_index(self, sent_sizes): + """ + Build index mapping block indices to the underlying dataset indices + """ + dataset_index = [] + ds_idx, ds_remaining = -1, 0 + for to_consume in sent_sizes: + sent_size = to_consume + if ds_remaining == 0: + ds_idx += 1 + ds_remaining = sent_sizes[ds_idx] + start_ds_idx = ds_idx + start_offset = sent_sizes[ds_idx] - ds_remaining + while to_consume > ds_remaining: + to_consume -= ds_remaining + ds_idx += 1 + ds_remaining = sent_sizes[ds_idx] + ds_remaining -= to_consume + dataset_index.append( + ( + start_ds_idx, # starting index in dataset + start_offset, # starting offset within starting index + ds_idx, # ending index in dataset + sent_size, # sentence length + ) + ) + assert ds_remaining == 0 + assert ds_idx == len(self.dataset) - 1 + return dataset_index + + def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes): + """ + Go through a single document and genrate sentence paris from it + """ + current_chunk = [] + current_length = 0 + curr = 0 + # To provide more randomness, we decrease target seq length for parts of + # samples (10% by default). Note that max_num_tokens is the hard threshold + # for batching and will never be changed. + target_seq_length = max_num_tokens + if np.random.random() < self.short_seq_prob: + target_seq_length = np.random.randint(2, max_num_tokens) + # loop through all sentences in document + while curr < len(doc): + sent_id = doc[curr] + current_chunk.append(sent_id) + current_length = sum(sizes[current_chunk]) + # split chunk and generate pair when exceed target_seq_length or + # finish the loop + if curr == len(doc) - 1 or current_length >= target_seq_length: + # split the chunk into 2 parts + a_end = 1 + if len(current_chunk) > 2: + a_end = np.random.randint(1, len(current_chunk) - 1) + sent_a = current_chunk[:a_end] + len_a = sum(sizes[sent_a]) + # generate next sentence label, note that if there is only 1 sentence + # in current chunk, label is always 0 + next_sent_label = ( + 1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0 + ) + if not next_sent_label: + # if next sentence label is 0, sample sent_b from a random doc + target_b_length = target_seq_length - len_a + rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id]) + random_doc = self.block_indices[rand_doc_id] + random_start = np.random.randint(0, len(random_doc)) + sent_b = [] + len_b = 0 + for j in range(random_start, len(random_doc)): + sent_b.append(random_doc[j]) + len_b = sum(sizes[sent_b]) + if len_b >= target_b_length: + break + # return the second part of the chunk since it's not used + num_unused_segments = len(current_chunk) - a_end + curr -= num_unused_segments + else: + # if next sentence label is 1, use the second part of chunk as sent_B + sent_b = current_chunk[a_end:] + len_b = sum(sizes[sent_b]) + # currently sent_a and sent_B may be longer than max_num_tokens, + # truncate them and return block idx and offsets for them + sent_a, sent_b = self._truncate_sentences( + sent_a, sent_b, max_num_tokens + ) + self.sent_pairs.append((sent_a, sent_b, next_sent_label)) + self.sizes.append(3 + sent_a[3] + sent_b[3]) + current_chunk = [] + curr += 1 + + def _skip_sampling(self, total, skip_ids): + """ + Generate a random integer which is not in skip_ids. Sample range is [0, total) + TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later + """ + rand_id = np.random.randint(total - len(skip_ids)) + return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids) + + def _truncate_sentences(self, sent_a, sent_b, max_num_tokens): + """ + Trancate a pair of sentence to limit total length under max_num_tokens + Logics: + 1. Truncate longer sentence + 2. Tokens to be truncated could be at the beginning or the end of the sentnce + Returns: + Truncated sentences represented by dataset idx + """ + len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b]) + front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0 + + while True: + total_length = ( + len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b + ) + if total_length <= max_num_tokens: + break + + if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b: + if np.random.rand() < 0.5: + front_cut_a += 1 + else: + end_cut_a += 1 + else: + if np.random.rand() < 0.5: + front_cut_b += 1 + else: + end_cut_b += 1 + + # calculate ds indices as well as offsets and return + truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a) + truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b) + return truncated_sent_a, truncated_sent_b + + def _cut_sentence(self, sent, front_cut, end_cut): + """ + Cut a sentence based on the numbers of tokens to be cut from beginning and end + Represent the sentence as dataset idx and return + """ + start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0 + target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut + while front_cut > 0: + if self.dataset.sizes[start_ds_idx] > front_cut: + offset += front_cut + break + else: + front_cut -= self.dataset.sizes[start_ds_idx] + start_ds_idx += 1 + while end_cut > 0: + if self.dataset.sizes[end_ds_idx] > end_cut: + break + else: + end_cut -= self.dataset.sizes[end_ds_idx] + end_ds_idx -= 1 + return start_ds_idx, offset, end_ds_idx, target_len + + def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length): + """ + Fetch a block of tokens based on its dataset idx + """ + buffer = torch.cat( + [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] + ) + s, e = offset, offset + length + return buffer[s:e] + + def __getitem__(self, index): + block1, block2, next_sent_label = self.sent_pairs[index] + block1 = self._fetch_block(*block1) + block2 = self._fetch_block(*block2) + return block1, block2, next_sent_label + + def __len__(self): + return len(self.sizes) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + prefetch_idx = set() + for index in indices: + for block1, block2, _ in [self.sent_pairs[index]]: + for ds_idx in range(block1[0], block1[2] + 1): + prefetch_idx.add(ds_idx) + for ds_idx in range(block2[0], block2[2] + 1): + prefetch_idx.add(ds_idx) + self.dataset.prefetch(prefetch_idx) diff --git a/fairseq/data/legacy/masked_lm_dataset.py b/fairseq/data/legacy/masked_lm_dataset.py new file mode 100644 index 0000000..dd8ea2c --- /dev/null +++ b/fairseq/data/legacy/masked_lm_dataset.py @@ -0,0 +1,303 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, List, Tuple + +import numpy as np +import torch +from fairseq.data import Dictionary, FairseqDataset, data_utils +from fairseq.data.concat_dataset import ConcatDataset +from fairseq.data.legacy.block_pair_dataset import BlockPairDataset +from fairseq.data.token_block_dataset import TokenBlockDataset + + +class MaskedLMDataset(FairseqDataset): + """ + A wrapper Dataset for masked language modelling. The dataset + wraps around TokenBlockDataset or BlockedPairDataset and creates a batch + where the input blocks are masked according to the specified masking + probability. Additionally the batch can also contain sentence level targets + if this is specified. + + Args: + dataset: Dataset which generates blocks of data. Only BlockPairDataset + and TokenBlockDataset are supported. + sizes: Sentence lengths + vocab: Dictionary with the vocabulary and special tokens. + pad_idx: Id of padding token in dictionary + mask_idx: Id of mask token in dictionary + classif_token_idx: Id of classification token in dictionary. This is the + token associated with the sentence embedding (Eg: CLS for BERT) + sep_token_idx: Id of separator token in dictionary + (Eg: SEP in BERT) + seed: Seed for random number generator for reproducibility. + shuffle: Shuffle the elements before batching. + has_pairs: Specifies whether the underlying dataset + generates a pair of blocks along with a sentence_target or not. + Setting it to True assumes that the underlying dataset generates a + label for the pair of sentences which is surfaced as + sentence_target. The default value assumes a single block with no + sentence target. + segment_id: An optional segment id for filling in the segment labels + when we are in the single block setting (Eg: XLM). Default is 0. + masking_ratio: specifies what percentage of the blocks should be masked. + masking_prob: specifies the probability of a given token being + replaced with the "MASK" token. + random_token_prob: specifies the probability of a given token being + replaced by a random token from the vocabulary. + """ + + def __init__( + self, + dataset: FairseqDataset, + sizes: np.ndarray, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + classif_token_idx: int, + sep_token_idx: int, + seed: int = 1, + shuffle: bool = True, + has_pairs: bool = True, + segment_id: int = 0, + masking_ratio: float = 0.15, + masking_prob: float = 0.8, + random_token_prob: float = 0.1, + ): + # Make sure the input datasets are the ones supported + assert ( + isinstance(dataset, TokenBlockDataset) + or isinstance(dataset, BlockPairDataset) + or isinstance(dataset, ConcatDataset) + ), ( + "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or " + "ConcatDataset" + ) + + self.dataset = dataset + self.sizes = np.array(sizes) + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.classif_token_idx = classif_token_idx + self.sep_token_idx = sep_token_idx + self.shuffle = shuffle + self.seed = seed + self.has_pairs = has_pairs + self.segment_id = segment_id + self.masking_ratio = masking_ratio + self.masking_prob = masking_prob + self.random_token_prob = random_token_prob + + # If we have only one block then sizes needs to be updated to include + # the classification token + if not has_pairs: + self.sizes = self.sizes + 1 + + def __getitem__(self, index: int): + # if has_pairs, then expect 2 blocks and a sentence target + if self.has_pairs: + (block_one, block_two, sentence_target) = self.dataset[index] + else: + block_one = self.dataset[index] + + return { + "id": index, + "block_one": block_one, + "block_two": block_two if self.has_pairs else None, + "sentence_target": sentence_target if self.has_pairs else None, + } + + def __len__(self): + return len(self.dataset) + + def _mask_block( + self, + sentence: np.ndarray, + mask_idx: int, + pad_idx: int, + dictionary_token_range: Tuple, + ): + """ + Mask tokens for Masked Language Model training + Samples mask_ratio tokens that will be predicted by LM. + + Note:This function may not be efficient enough since we had multiple + conversions between np and torch, we can replace them with torch + operators later. + + Args: + sentence: 1d tensor to be masked + mask_idx: index to use for masking the sentence + pad_idx: index to use for masking the target for tokens we aren't + predicting + dictionary_token_range: range of indices in dictionary which can + be used for random word replacement + (e.g. without special characters) + Return: + masked_sent: masked sentence + target: target with words which we are not predicting replaced + by pad_idx + """ + masked_sent = np.copy(sentence) + sent_length = len(sentence) + mask_num = math.ceil(sent_length * self.masking_ratio) + mask = np.random.choice(sent_length, mask_num, replace=False) + target = np.copy(sentence) + + for i in range(sent_length): + if i in mask: + rand = np.random.random() + + # replace with mask if probability is less than masking_prob + # (Eg: 0.8) + if rand < self.masking_prob: + masked_sent[i] = mask_idx + + # replace with random token if probability is less than + # masking_prob + random_token_prob (Eg: 0.9) + elif rand < (self.masking_prob + self.random_token_prob): + # sample random token from dictionary + masked_sent[i] = np.random.randint( + dictionary_token_range[0], dictionary_token_range[1] + ) + else: + target[i] = pad_idx + + return masked_sent, target + + def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int): + """ + Does the heavy lifting for creating a batch from the input list of + examples. The logic is as follows: + 1. Mask the input blocks. In case has_pair is True then we have 2 + blocks to mask. + 2. Prepend the first masked block tensor with the special token + used as sentence embedding. Eg: CLS in BERT. This happens + irrespective of the value of has_pair. + 3. If has_pair is True, then append the first masked block with the + special separator token (eg: SEP for BERT) and compute segment + label accordingly. In this case, also append the second masked + block with this special separator token and compute its segment + label. + 4. For the targets tensor, prepend and append with padding index + accordingly. + 5. Concatenate all tensors. + """ + if len(samples) == 0: + return {} + # To ensure determinism, we reset the state of the PRNG after every + # batch based on the seed and the first id of the batch. This ensures + # that across epochs we get the same mask for the same example. This + # is needed for reproducibility and is how BERT does masking + # TODO: Can we add deteminism without this constraint? + with data_utils.numpy_seed(self.seed + samples[0]["id"]): + for s in samples: + + # token range is needed for replacing with random token during + # masking + token_range = (self.vocab.nspecial, len(self.vocab)) + + # mask according to specified probabilities. + masked_blk_one, masked_tgt_one = self._mask_block( + s["block_one"], + self.mask_idx, + self.pad_idx, + token_range, + ) + + tokens = np.concatenate([[self.classif_token_idx], masked_blk_one]) + targets = np.concatenate([[self.pad_idx], masked_tgt_one]) + segments = np.ones(len(tokens)) * self.segment_id + + # if has_pairs is True then we need to add the SEP token to both + # the blocks after masking and re-compute segments based on the new + # lengths. + if self.has_pairs: + tokens_one = np.concatenate([tokens, [self.sep_token_idx]]) + targets_one = np.concatenate([targets, [self.pad_idx]]) + + masked_blk_two, masked_tgt_two = self._mask_block( + s["block_two"], self.mask_idx, self.pad_idx, token_range + ) + tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]]) + targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]]) + + # block + 1 sep + 1 special (CLS) + segments_one = np.zeros(len(tokens_one)) + # block + 1 sep + segments_two = np.ones(len(tokens_two)) + + tokens = np.concatenate([tokens_one, tokens_two]) + targets = np.concatenate([targets_one, targets_two]) + segments = np.concatenate([segments_one, segments_two]) + + s["source"] = torch.LongTensor(tokens) + s["segment_labels"] = torch.LongTensor(segments) + s["lm_target"] = torch.LongTensor(targets) + + def merge(key): + return data_utils.collate_tokens( + [s[key] for s in samples], pad_idx, eos_idx, left_pad=False + ) + + return { + "id": torch.LongTensor([s["id"] for s in samples]), + "ntokens": sum(len(s["source"]) for s in samples), + "net_input": { + "src_tokens": merge("source"), + "segment_labels": merge("segment_labels"), + }, + "lm_target": merge("lm_target"), + "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples]) + if self.has_pairs + else None, + "nsentences": len(samples), + } + + def collater(self, samples: List[Dict]): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch of data + """ + return self._collate(samples, self.vocab.pad(), self.vocab.eos()) + + def num_tokens(self, index: int): + """ + Return the number of tokens in a sample. This value is used to + enforce max-tokens during batching. + """ + return self.sizes[index] + + def size(self, index: int): + """ + Return an example's size as a float or tuple. This value is used when + filtering a dataset with max-positions. + """ + return self.sizes[index] + + def ordered_indices(self): + """ + Return an ordered list of indices. Batches will be constructed based + on this order. + """ + if self.shuffle: + return np.random.permutation(len(self)) + else: + order = [np.arange(len(self))] + order.append(self.sizes) + return np.lexsort(order) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + self.dataset.prefetch(indices) diff --git a/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq/data/legacy/masked_lm_dictionary.py new file mode 100644 index 0000000..dee88f7 --- /dev/null +++ b/fairseq/data/legacy/masked_lm_dictionary.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data import Dictionary + + +class MaskedLMDictionary(Dictionary): + """ + Dictionary for Masked Language Modelling tasks. This extends Dictionary by + adding the mask symbol. + """ + + def __init__( + self, + pad="", + eos="", + unk="", + mask="", + ): + super().__init__(pad=pad, eos=eos, unk=unk) + self.mask_word = mask + self.mask_index = self.add_symbol(mask) + self.nspecial = len(self.symbols) + + def mask(self): + """Helper to get index of mask symbol""" + return self.mask_index + + +class BertDictionary(MaskedLMDictionary): + """ + Dictionary for BERT task. This extends MaskedLMDictionary by adding support + for cls and sep symbols. + """ + + def __init__( + self, + pad="", + eos="", + unk="", + mask="", + cls="", + sep="", + ): + super().__init__(pad=pad, eos=eos, unk=unk, mask=mask) + self.cls_word = cls + self.sep_word = sep + self.cls_index = self.add_symbol(cls) + self.sep_index = self.add_symbol(sep) + self.nspecial = len(self.symbols) + + def cls(self): + """Helper to get index of cls symbol""" + return self.cls_index + + def sep(self): + """Helper to get index of sep symbol""" + return self.sep_index diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py new file mode 100644 index 0000000..2b12646 --- /dev/null +++ b/fairseq/data/plasma_utils.py @@ -0,0 +1,91 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import subprocess +import tempfile + + +class PlasmaArray(object): + """ + Wrapper around numpy arrays that automatically moves the data to shared + memory upon serialization. This is particularly helpful when passing numpy + arrays through multiprocessing, so that data is not unnecessarily + duplicated or pickled. + """ + + def __init__(self, array): + super().__init__() + self.array = array + self.disable = array.nbytes < 134217728 # disable for arrays <128MB + self.object_id = None + self.path = None + + # variables with underscores shouldn't be pickled + self._client = None + self._server = None + self._server_tmp = None + self._plasma = None + + @property + def plasma(self): + if self._plasma is None and not self.disable: + try: + import pyarrow.plasma as plasma + + self._plasma = plasma + except ImportError: + self._plasma = None + return self._plasma + + def start_server(self): + if self.plasma is None or self._server is not None: + return + assert self.object_id is None + assert self.path is None + self._server_tmp = tempfile.NamedTemporaryFile() + self.path = self._server_tmp.name + self._server = subprocess.Popen( + [ + "plasma_store", + "-m", + str(int(1.05 * self.array.nbytes)), + "-s", + self.path, + ] + ) + + @property + def client(self): + if self._client is None: + assert self.path is not None + self._client = self.plasma.connect(self.path) + return self._client + + def __getstate__(self): + if self.plasma is None: + return self.__dict__ + if self.object_id is None: + self.start_server() + self.object_id = self.client.put(self.array) + state = self.__dict__.copy() + del state["array"] + state["_client"] = None + state["_server"] = None + state["_server_tmp"] = None + state["_plasma"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + if self.plasma is None: + return + self.array = self.client.get(self.object_id) + + def __del__(self): + if self._server is not None: + self._server.kill() + self._server = None + self._server_tmp.close() + self._server_tmp = None diff --git a/fairseq/data/resampling_dataset.py b/fairseq/data/resampling_dataset.py new file mode 100644 index 0000000..3d3b993 --- /dev/null +++ b/fairseq/data/resampling_dataset.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +from fairseq.data import BaseWrapperDataset, plasma_utils + + +logger = logging.getLogger(__name__) + + +class ResamplingDataset(BaseWrapperDataset): + """Randomly samples from a given dataset at each epoch. + + Sampling is done with or without replacement, depending on the "replace" + parameter. + + Optionally, the epoch size can be rescaled. This is potentially desirable + to increase per-epoch coverage of the base dataset (since sampling with + replacement means that many items in the dataset will be left out). In the + case of sampling without replacement, size_ratio should be strictly less + than 1. + + Args: + dataset (~torch.utils.data.Dataset): dataset on which to sample. + weights (List[float]): list of probability weights + (default: None, which corresponds to uniform sampling). + replace (bool): sampling mode; True for "with replacement", or False + for "without replacement" (default: True) + size_ratio (float): the ratio to subsample to; must be positive + (default: 1.0). + batch_by_size (bool): whether or not to batch by sequence length + (default: True). + seed (int): RNG seed to use (default: 0). + epoch (int): starting epoch number (default: 1). + """ + + def __init__( + self, + dataset, + weights=None, + replace=True, + size_ratio=1.0, + batch_by_size=True, + seed=0, + epoch=1, + ): + super().__init__(dataset) + + if weights is None: + self.weights = None + + else: + assert len(weights) == len(dataset) + weights_arr = np.array(weights, dtype=np.float64) + weights_arr /= weights_arr.sum() + self.weights = plasma_utils.PlasmaArray(weights_arr) + + self.replace = replace + + assert size_ratio > 0.0 + if not self.replace: + assert size_ratio < 1.0 + self.size_ratio = float(size_ratio) + self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int) + + self.batch_by_size = batch_by_size + self.seed = seed + + self._cur_epoch = None + self._cur_indices = None + + self.set_epoch(epoch) + + def __getitem__(self, index): + return self.dataset[self._cur_indices.array[index]] + + def __len__(self): + return self.actual_size + + @property + def sizes(self): + if isinstance(self.dataset.sizes, list): + return [s[self._cur_indices.array] for s in self.dataset.sizes] + return self.dataset.sizes[self._cur_indices.array] + + def num_tokens(self, index): + return self.dataset.num_tokens(self._cur_indices.array[index]) + + def size(self, index): + return self.dataset.size(self._cur_indices.array[index]) + + def ordered_indices(self): + if self.batch_by_size: + order = [ + np.arange(len(self)), + self.sizes, + ] # No need to handle `self.shuffle == True` + return np.lexsort(order) + else: + return np.arange(len(self)) + + def prefetch(self, indices): + self.dataset.prefetch(self._cur_indices.array[indices]) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch): + logger.debug("ResamplingDataset.set_epoch: {}".format(epoch)) + super().set_epoch(epoch) + + if epoch == self._cur_epoch: + return + + self._cur_epoch = epoch + + # Generate a weighted sample of indices as a function of the + # random seed and the current epoch. + + rng = np.random.RandomState( + [ + 42, # magic number + self.seed % (2 ** 32), # global seed + self._cur_epoch, # epoch index + ] + ) + self._cur_indices = plasma_utils.PlasmaArray( + rng.choice( + len(self.dataset), + self.actual_size, + replace=self.replace, + p=(None if self.weights is None else self.weights.array), + ) + ) diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py new file mode 100644 index 0000000..aa33f9d --- /dev/null +++ b/fairseq/data/token_block_dataset.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from fairseq.data import FairseqDataset, plasma_utils + + +class TokenBlockDataset(FairseqDataset): + """Break a Dataset of tokens into blocks. + + Args: + dataset (~torch.utils.data.Dataset): dataset to break into blocks + sizes (List[int]): sentence lengths (required for 'complete' and 'eos') + block_size (int): maximum block size (ignored in 'eos' break mode) + break_mode (str, optional): Mode used for breaking tokens. Values can + be one of: + - 'none': break tokens into equally sized blocks (up to block_size) + - 'complete': break tokens into blocks (up to block_size) such that + blocks contains complete sentences, although block_size may be + exceeded if some sentences exceed block_size + - 'complete_doc': similar to 'complete' mode, but do not + cross document boundaries + - 'eos': each block contains one sentence (block_size is ignored) + include_targets (bool, optional): return next tokens as targets + (default: False). + document_sep_len (int, optional): document separator size (required for + 'complete_doc' break mode). Typically 1 if the sentences have eos + and 0 otherwise. + """ + + def __init__( + self, + dataset, + sizes, + block_size, + pad, + eos, + break_mode=None, + include_targets=False, + document_sep_len=1, + ): + try: + from fairseq.data.token_block_utils_fast import ( + _get_slice_indices_fast, + _get_block_to_dataset_index_fast, + ) + except ImportError: + raise ImportError( + "Please build Cython components with: `pip install --editable .` " + "or `python setup.py build_ext --inplace`" + ) + + super().__init__() + self.dataset = dataset + self.pad = pad + self.eos = eos + self.include_targets = include_targets + + assert len(dataset) == len(sizes) + assert len(dataset) > 0 + + if isinstance(sizes, list): + sizes = np.array(sizes, dtype=np.int64) + else: + if torch.is_tensor(sizes): + sizes = sizes.numpy() + sizes = sizes.astype(np.int64) + + break_mode = break_mode if break_mode is not None else "none" + + # For "eos" break-mode, block_size is not required parameters. + if break_mode == "eos" and block_size is None: + block_size = 0 + + slice_indices = _get_slice_indices_fast( + sizes, str(break_mode), block_size, document_sep_len + ) + self._sizes = slice_indices[:, 1] - slice_indices[:, 0] + + # build index mapping block indices to the underlying dataset indices + if break_mode == "eos": + # much faster version for eos break mode + block_to_dataset_index = np.stack( + [ + np.arange(len(sizes)), # starting index in dataset + np.zeros( + len(sizes), dtype=np.long + ), # starting offset within starting index + np.arange(len(sizes)), # ending index in dataset + ], + 1, + ) + else: + block_to_dataset_index = _get_block_to_dataset_index_fast( + sizes, + slice_indices, + ) + self._slice_indices = plasma_utils.PlasmaArray(slice_indices) + self._sizes = plasma_utils.PlasmaArray(self._sizes) + self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index) + + @property + def slice_indices(self): + return self._slice_indices.array + + @property + def sizes(self): + return self._sizes.array + + @property + def block_to_dataset_index(self): + return self._block_to_dataset_index.array + + def attr(self, attr: str, index: int): + start_ds_idx, _, _ = self.block_to_dataset_index[index] + return self.dataset.attr(attr, start_ds_idx) + + def __getitem__(self, index): + start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index] + + buffer = torch.cat( + [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] + ) + + slice_s, slice_e = self.slice_indices[index] + length = slice_e - slice_s + s, e = start_offset, start_offset + length + item = buffer[s:e] + + if self.include_targets: + # *target* is the original sentence (=item) + # *source* is shifted right by 1 (maybe left-padded with eos) + # *past_target* is shifted right by 2 (left-padded as needed) + if s == 0: + source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]]) + past_target = torch.cat( + [item.new([self.pad, self.eos]), buffer[0 : e - 2]] + ) + else: + source = buffer[s - 1 : e - 1] + if s == 1: + past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]]) + else: + past_target = buffer[s - 2 : e - 2] + + return source, item, past_target + + return item + + def __len__(self): + return len(self.slice_indices) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + self.dataset.prefetch( + { + ds_idx + for index in indices + for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]] + for ds_idx in range(start_ds_idx, end_ds_idx + 1) + } + ) diff --git a/fairseq/data/token_block_utils_fast.pyx b/fairseq/data/token_block_utils_fast.pyx new file mode 100644 index 0000000..5a2f16e --- /dev/null +++ b/fairseq/data/token_block_utils_fast.pyx @@ -0,0 +1,187 @@ +# cython: language_level=3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from itertools import chain +from libc.math cimport ceil + +cimport cython +cimport numpy as np + +from libc.stdint cimport int32_t, int64_t + +DTYPE = np.int64 +ctypedef int64_t DTYPE_t + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size): + cdef DTYPE_t total_size = sizes.sum() + cdef DTYPE_t length = ceil(total_size / block_size) + cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE) + cdef DTYPE_t[:, :] slice_indices_view = slice_indices + cdef DTYPE_t i + cdef DTYPE_t start + cdef DTYPE_t end + for i in range(length): + start = i * block_size + end = min(start + block_size, total_size) + slice_indices_view[i][0] = start + slice_indices_view[i][1] = end + return slice_indices + + +cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list): + """ + Faster function to convert DTYPE_t list of list. + Only fast when there are huge number of rows and low number of columns. + """ + cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1) + return flat.reshape((len(list_of_list), -1)) + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): + cdef DTYPE_t tok_idx = 0 + cdef DTYPE_t sz_idx = 0 + cdef DTYPE_t curr_size = 0 + cdef DTYPE_t i = 0 + cdef DTYPE_t length + cdef DTYPE_t total_size + cdef DTYPE_t[:] sizes_view = sizes + cdef np.ndarray[DTYPE_t, ndim=2] slice_indices + cdef list slice_indices_list = [] + + if break_mode is None or break_mode == 'none': + slice_indices = _get_slice_indices_none_mode(sizes, block_size) + elif break_mode == 'complete': + while sz_idx < len(sizes_view): + if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: + curr_size += sizes_view[sz_idx] + sz_idx += 1 + else: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + tok_idx += curr_size + curr_size = 0 + if curr_size > 0: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + slice_indices = _fast_convert_to_np_array(slice_indices_list) + elif break_mode == 'complete_doc': + while sz_idx < len(sizes_view): + if ( + (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0) + # an empty sentence indicates end-of-document: + and sizes_view[sz_idx] != document_sep_len + ): + curr_size += sizes_view[sz_idx] + sz_idx += 1 + else: + # Only keep non-empty documents. + if curr_size > 1: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + tok_idx += curr_size + curr_size = 0 + if sizes_view[sz_idx] == document_sep_len: + tok_idx += sizes_view[sz_idx] + sz_idx += 1 + if curr_size > 1: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + slice_indices = _fast_convert_to_np_array(slice_indices_list) + elif break_mode == 'eos': + slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + cumsum = sizes.cumsum(axis=0) + slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] + slice_indices[:, 1] = cumsum + else: + raise ValueError('Invalid break_mode: ' + break_mode) + return slice_indices + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices): + cdef DTYPE_t start_ds_idx + cdef DTYPE_t start_offset + cdef DTYPE_t end_ds_idx + cdef DTYPE_t i + cdef DTYPE_t s + cdef DTYPE_t e + cdef DatasetSearcher ds = DatasetSearcher(sizes) + cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE) + cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index + cdef DTYPE_t[:, :] slice_indices_view = slice_indices + cdef Py_ssize_t x_max = slice_indices.shape[0] + + for i in range(x_max): + s = slice_indices_view[i][0] + e = slice_indices_view[i][1] + ds.seek(s) + start_ds_idx = ds.current_index + start_offset = ds.current_offset + if e <= s: + end_ds_idx = start_ds_idx + else: + ds.seek(e - 1) + end_ds_idx = ds.current_index + block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset + block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index + block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset + return block_to_dataset_index + + +cdef class DatasetSearcher(object): + """Helper for mapping "flat" indices to indices and offsets in an + underlying dataset.""" + cdef DTYPE_t current_i + cdef DTYPE_t current_offset + cdef DTYPE_t current_index + cdef DTYPE_t[:] sizes + + def __init__(self, DTYPE_t[:] sizes): + self.sizes = sizes + self.reset() + + cdef reset(self): + self.current_offset = 0 # offset within current index in underlying dataset + self.current_i = 0 # "flat" index + self.current_index = 0 # index in underlying dataset + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.nonecheck(False) + cdef int step(self, DTYPE_t i): + cdef DTYPE_t to_consume + cdef DTYPE_t remaining + if i < self.current_i: + self.reset() + if i > self.current_i: + to_consume = i - self.current_i + remaining = self.sizes[self.current_index] - self.current_offset + if remaining > to_consume: + self.current_offset += to_consume + self.current_i += to_consume + else: + assert remaining > 0 + self.current_i += remaining + self.current_index += 1 + self.current_offset = 0 + return 1 + return 0 + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.nonecheck(False) + cdef seek(self, DTYPE_t i): + cdef int not_done = 1 + while not_done == 1: + not_done = self.step(i) + assert self.current_i == i diff --git a/fairseq/dataclass/__init__.py b/fairseq/dataclass/__init__.py new file mode 100644 index 0000000..5c9004d --- /dev/null +++ b/fairseq/dataclass/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .configs import FairseqDataclass +from .constants import ChoiceEnum + + +__all__ = ["FairseqDataclass", "ChoiceEnum"] diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py new file mode 100644 index 0000000..484d252 --- /dev/null +++ b/fairseq/dataclass/configs.py @@ -0,0 +1,876 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys +from dataclasses import _MISSING_TYPE, dataclass, field +from typing import Any, List, Optional + +import torch + +from fairseq.dataclass.constants import ( + DATASET_IMPL_CHOICES, + DDP_BACKEND_CHOICES, + DISTRIBUTED_WRAPPER_CHOICES, + GENERATION_CONSTRAINTS_CHOICES, + GENERATION_DECODING_FORMAT_CHOICES, + LOG_FORMAT_CHOICES, + PIPELINE_CHECKPOINT_CHOICES, + ZERO_SHARDING_CHOICES, +) + +from omegaconf import II + + +@dataclass +class FairseqDataclass: + """fairseq base dataclass that supported fetching attributes and metas""" + + _name: Optional[str] = None + + @staticmethod + def name(): + return None + + def _get_all_attributes(self) -> List[str]: + return [k for k in self.__dataclass_fields__.keys()] + + def _get_meta( + self, attribute_name: str, meta: str, default: Optional[Any] = None + ) -> Any: + return self.__dataclass_fields__[attribute_name].metadata.get(meta, default) + + def _get_name(self, attribute_name: str) -> str: + return self.__dataclass_fields__[attribute_name].name + + def _get_default(self, attribute_name: str) -> Any: + if hasattr(self, attribute_name): + if str(getattr(self, attribute_name)).startswith("${"): + return str(getattr(self, attribute_name)) + elif str(self.__dataclass_fields__[attribute_name].default).startswith( + "${" + ): + return str(self.__dataclass_fields__[attribute_name].default) + elif ( + getattr(self, attribute_name) + != self.__dataclass_fields__[attribute_name].default + ): + return getattr(self, attribute_name) + + f = self.__dataclass_fields__[attribute_name] + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + def _get_type(self, attribute_name: str) -> Any: + return self.__dataclass_fields__[attribute_name].type + + def _get_help(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "help") + + def _get_argparse_const(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_const") + + def _get_argparse_alias(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_alias") + + def _get_choices(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "choices") + + +@dataclass +class CommonConfig(FairseqDataclass): + # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were + # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc. + no_progress_bar: bool = field( + default=False, metadata={"help": "disable progress bar"} + ) + log_interval: int = field( + default=100, + metadata={ + "help": "log progress every N batches (when progress bar is disabled)" + }, + ) + log_format: Optional[LOG_FORMAT_CHOICES] = field( + default=None, metadata={"help": "log format to use"} + ) + tensorboard_logdir: Optional[str] = field( + default=None, + metadata={ + "help": "path to save logs for tensorboard, should match --logdir " + "of running tensorboard (default: no tensorboard logging)" + }, + ) + seed: int = field( + default=1, metadata={"help": "pseudo random number generator seed"} + ) + cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"}) + tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"}) + bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"}) + memory_efficient_bf16: bool = field( + default=False, + metadata={ + "help": "use a memory-efficient version of BF16 training; implies --bf16" + }, + ) + fp16: bool = field(default=False, metadata={"help": "use FP16"}) + memory_efficient_fp16: bool = field( + default=False, + metadata={ + "help": "use a memory-efficient version of FP16 training; implies --fp16" + }, + ) + fp16_no_flatten_grads: bool = field( + default=False, metadata={"help": "don't flatten FP16 grads tensor"} + ) + fp16_init_scale: int = field( + default=2 ** 7, metadata={"help": "default FP16 loss scale"} + ) + fp16_scale_window: Optional[int] = field( + default=None, + metadata={"help": "number of updates before increasing loss scale"}, + ) + fp16_scale_tolerance: float = field( + default=0.0, + metadata={ + "help": "pct of updates that can overflow before decreasing the loss scale" + }, + ) + min_loss_scale: float = field( + default=1e-4, + metadata={"help": "minimum FP16 loss scale, after which training is stopped"}, + ) + threshold_loss_scale: Optional[float] = field( + default=None, metadata={"help": "threshold FP16 loss scale from below"} + ) + user_dir: Optional[str] = field( + default=None, + metadata={ + "help": "path to a python module containing custom extensions (tasks and/or architectures)" + }, + ) + empty_cache_freq: int = field( + default=0, + metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"}, + ) + all_gather_list_size: int = field( + default=16384, + metadata={"help": "number of bytes reserved for gathering stats from workers"}, + ) + model_parallel_size: int = field( + default=1, metadata={"help": "total number of GPUs to parallelize model over"} + ) + quantization_config_path: Optional[str] = field( + default=None, metadata={"help": "path to quantization config file"} + ) + profile: bool = field( + default=False, metadata={"help": "enable autograd profiler emit_nvtx"} + ) + + +@dataclass +class DistributedTrainingConfig(FairseqDataclass): + distributed_world_size: int = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "total number of GPUs across all nodes (default: all visible GPUs)" + }, + ) + distributed_rank: Optional[int] = field( + default=0, metadata={"help": "rank of the current worker"} + ) + distributed_backend: str = field( + default="nccl", metadata={"help": "distributed backend"} + ) + distributed_init_method: Optional[str] = field( + default=None, + metadata={ + "help": "typically tcp://hostname:port that will be used to " + "establish initial connetion" + }, + ) + distributed_port: int = field( + default=-1, + metadata={ + "help": "port number (not required if using --distributed-init-method)" + }, + ) + device_id: int = field( + default=0, + metadata={"help": "which GPU to use (usually configured automatically)"}, + ) + local_rank: int = field( + default=0, + metadata={"help": "which GPU to use (usually configured automatically)"}, + ) + distributed_no_spawn: bool = field( + default=False, + metadata={ + "help": "do not spawn multiple processes even if multiple GPUs are visible" + }, + ) + ddp_backend: DDP_BACKEND_CHOICES = field( + default="c10d", metadata={"help": "DistributedDataParallel backend"} + ) + bucket_cap_mb: int = field( + default=25, metadata={"help": "bucket size for reduction"} + ) + fix_batches_to_gpus: bool = field( + default=False, + metadata={ + "help": "don't shuffle batches between GPUs; this reduces overall " + "randomness and may affect precision but avoids the cost of re-reading the data" + }, + ) + find_unused_parameters: bool = field( + default=False, + metadata={ + "help": "disable unused parameter detection (not applicable to " + "no_c10d ddp-backend" + }, + ) + fast_stat_sync: bool = field( + default=False, + metadata={"help": "[deprecated] this is now defined per Criterion"}, + ) + broadcast_buffers: bool = field( + default=False, + metadata={ + "help": "Copy non-trainable parameters between GPUs, such as " + "batchnorm population statistics" + }, + ) + distributed_wrapper: DISTRIBUTED_WRAPPER_CHOICES = field( + default="DDP", metadata={"help": "DistributedDataParallel backend"} + ) + slowmo_momentum: Optional[float] = field( + default=None, + metadata={ + "help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, " + "0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs" + }, + ) + slowmo_algorithm: str = field( + default="LocalSGD", metadata={"help": "whether to use LocalSGD or SGP"} + ) + localsgd_frequency: int = field( + default=3, metadata={"help": "Local SGD allreduce frequency"} + ) + nprocs_per_node: int = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "number of GPUs in each node. An allreduce operation across GPUs in " + "a node is very fast. Hence, we do allreduce across GPUs in a node, " + "and gossip across different nodes" + }, + ) + pipeline_model_parallel: bool = field( + default=False, + metadata={"help": "if set, use pipeline model parallelism across GPUs"}, + ) + pipeline_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the model into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_balance) " + "should equal the total number of layers in the model" + }, + ) + pipeline_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-balance argument" + }, + ) + pipeline_chunks: Optional[int] = field( + default=0, metadata={"help": "microbatch count for pipeline model parallelism"} + ) + pipeline_encoder_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the pipeline parallel encoder into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_encoder_balance) " + "should equal the total number of encoder layers in the model" + }, + ) + pipeline_encoder_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-encoder-balance argument" + }, + ) + pipeline_decoder_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the pipeline parallel decoder into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_decoder_balance) " + "should equal the total number of decoder layers in the model" + }, + ) + pipeline_decoder_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-decoder-balance argument" + }, + ) + pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( + default="never", + metadata={"help": "checkpointing mode for pipeline model parallelism"}, + ) + zero_sharding: ZERO_SHARDING_CHOICES = field( + default="none", metadata={"help": "ZeRO sharding"} + ) + tpu: bool = II("common.tpu") + + +@dataclass +class DatasetConfig(FairseqDataclass): + num_workers: int = field( + default=1, metadata={"help": "how many subprocesses to use for data loading"} + ) + skip_invalid_size_inputs_valid_test: bool = field( + default=False, + metadata={"help": "ignore too long or too short lines in valid and test set"}, + ) + max_tokens: Optional[int] = field( + default=None, metadata={"help": "maximum number of tokens in a batch"} + ) + batch_size: Optional[int] = field( + default=None, + metadata={ + "help": "number of examples in a batch", + "argparse_alias": "--max-sentences", + }, + ) + required_batch_size_multiple: int = field( + default=8, metadata={"help": "batch size will be a multiplier of this value"} + ) + required_seq_len_multiple: int = field( + default=1, + metadata={ + "help": "maximum sequence length in batch will be a multiplier of this value" + }, + ) + dataset_impl: Optional[DATASET_IMPL_CHOICES] = field( + default=None, metadata={"help": "output dataset implementation"} + ) + data_buffer_size: int = field( + default=10, metadata={"help": "Number of batches to preload"} + ) + train_subset: str = field( + default="train", + metadata={"help": "data subset to use for training (e.g. train, valid, test)"}, + ) + valid_subset: str = field( + default="valid", + metadata={ + "help": "comma separated list of data subsets to use for validation" + " (e.g. train, valid, test)" + }, + ) + validate_interval: int = field( + default=1, metadata={"help": "validate every N epochs"} + ) + validate_interval_updates: int = field( + default=0, metadata={"help": "validate every N updates"} + ) + validate_after_updates: int = field( + default=0, metadata={"help": "dont validate until reaching this many updates"} + ) + fixed_validation_seed: Optional[int] = field( + default=None, metadata={"help": "specified random seed for validation"} + ) + disable_validation: bool = field( + default=False, metadata={"help": "disable validation"} + ) + max_tokens_valid: Optional[int] = field( + default=None, + metadata={ + "help": "maximum number of tokens in a validation batch" + " (defaults to --max-tokens)" + }, + ) + batch_size_valid: Optional[int] = field( + default=None, + metadata={ + "help": "batch size of the validation batch (defaults to --batch-size)", + "argparse_alias": "--max-sentences-valid", + }, + ) + curriculum: int = field( + default=0, metadata={"help": "don't shuffle batches for first N epochs"} + ) + gen_subset: str = field( + default="test", + metadata={"help": "data subset to generate (train, valid, test)"}, + ) + num_shards: int = field( + default=1, metadata={"help": "shard generation over N shards"} + ) + shard_id: int = field( + default=0, metadata={"help": "id of the shard to generate (id < num_shards)"} + ) + + +@dataclass +class OptimizationConfig(FairseqDataclass): + max_epoch: int = field( + default=0, metadata={"help": "force stop training at specified epoch"} + ) + max_update: int = field( + default=0, metadata={"help": "force stop training at specified update"} + ) + stop_time_hours: float = field( + default=0, + metadata={ + "help": "force stop training after specified cumulative time (if >0)" + }, + ) + clip_norm: float = field( + default=25.0, metadata={"help": "clip threshold of gradients"} + ) + sentence_avg: bool = field( + default=False, + metadata={ + "help": "normalize gradients by the number of sentences in a batch" + " (default is to normalize by number of tokens)" + }, + ) + update_freq: List[int] = field( + default_factory=lambda: [1], + metadata={"help": "update parameters every N_i batches, when in epoch i"}, + ) + lr: List[float] = field( + default_factory=lambda: [0.25], + metadata={ + "help": "learning rate for the first N epochs; all epochs >N using LR_N" + " (note: this may be interpreted differently depending on --lr-scheduler)" + }, + ) + min_lr: float = field( + default=-1.0, + metadata={"help": "stop training when the learning rate reaches this minimum"}, + ) + use_bmuf: bool = field( + default=False, + metadata={ + "help": "specify global optimizer for syncing models on different GPUs/shards" + }, + ) + + +@dataclass +class CheckpointConfig(FairseqDataclass): + save_dir: str = field( + default="checkpoints", metadata={"help": "path to save checkpoints"} + ) + restore_file: str = field( + default="checkpoint_last.pt", + metadata={ + "help": "filename from which to load checkpoint " + "(default: /checkpoint_last.pt" + }, + ) + finetune_from_model: Optional[str] = field( + default=None, + metadata={ + "help": "finetune from a pretrained model; note that meters and lr scheduler will be reset" + }, + ) + reset_dataloader: bool = field( + default=False, + metadata={ + "help": "if set, does not reload dataloader state from the checkpoint" + }, + ) + reset_lr_scheduler: bool = field( + default=False, + metadata={ + "help": "if set, does not load lr scheduler state from the checkpoint" + }, + ) + reset_meters: bool = field( + default=False, + metadata={"help": "if set, does not load meters from the checkpoint"}, + ) + reset_optimizer: bool = field( + default=False, + metadata={"help": "if set, does not load optimizer state from the checkpoint"}, + ) + optimizer_overrides: str = field( + default="{}", + metadata={ + "help": "a dictionary used to override optimizer args when loading a checkpoint" + }, + ) + save_interval: int = field( + default=1, metadata={"help": "save a checkpoint every N epochs"} + ) + save_interval_updates: int = field( + default=0, metadata={"help": "save a checkpoint (and validate) every N updates"} + ) + keep_interval_updates: int = field( + default=-1, + metadata={ + "help": "keep the last N checkpoints saved with --save-interval-updates" + }, + ) + keep_last_epochs: int = field( + default=-1, metadata={"help": "keep last N epoch checkpoints"} + ) + keep_best_checkpoints: int = field( + default=-1, metadata={"help": "keep best N checkpoints based on scores"} + ) + no_save: bool = field( + default=False, metadata={"help": "don't save models or checkpoints"} + ) + no_epoch_checkpoints: bool = field( + default=False, metadata={"help": "only store last and best checkpoints"} + ) + no_last_checkpoints: bool = field( + default=False, metadata={"help": "don't store last checkpoints"} + ) + no_save_optimizer_state: bool = field( + default=False, + metadata={"help": "don't save optimizer-state as part of checkpoint"}, + ) + best_checkpoint_metric: str = field( + default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'} + ) + maximize_best_checkpoint_metric: bool = field( + default=False, + metadata={ + "help": 'select the largest metric value for saving "best" checkpoints' + }, + ) + patience: int = field( + default=-1, + metadata={ + "help": ( + "early stop training if valid performance doesn't " + "improve for N consecutive validation runs; note " + "that this is influenced by --validate-interval" + ) + }, + ) + checkpoint_suffix: str = field( + default="", metadata={"help": "suffix to add to the checkpoint file name"} + ) + checkpoint_shard_count: int = field( + default=1, + metadata={ + "help": "Number of shards containing the checkpoint - " + "if the checkpoint is over 300GB, it is preferable " + "to split it into shards to prevent OOM on CPU while loading " + "the checkpoint" + }, + ) + model_parallel_size: int = II("common.model_parallel_size") + distributed_rank: int = II("distributed_training.distributed_rank") + + +@dataclass +class FairseqBMUFConfig(FairseqDataclass): + block_lr: float = field( + default=1, metadata={"help": "block learning rate for bmuf"} + ) + block_momentum: float = field( + default=0.875, metadata={"help": "block momentum for bmuf"} + ) + global_sync_iter: int = field( + default=50, metadata={"help": "Iteration for syncing global model"} + ) + warmup_iterations: int = field( + default=500, metadata={"help": "warmup iterations for model to broadcast"} + ) + use_nbm: bool = field( + default=False, + metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"}, + ) + average_sync: bool = field( + default=False, + metadata={ + "help": "Specify whether you want to average the local momentum after each sync" + }, + ) + distributed_world_size: int = II("distributed_training.distributed_world_size") + + +@dataclass +class GenerationConfig(FairseqDataclass): + beam: int = field( + default=5, + metadata={"help": "beam size"}, + ) + nbest: int = field( + default=1, + metadata={"help": "number of hypotheses to output"}, + ) + max_len_a: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + max_len_b: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + min_len: int = field( + default=1, + metadata={"help": "minimum generation length"}, + ) + match_source_len: bool = field( + default=False, + metadata={"help": "generations should match the source length"}, + ) + unnormalized: bool = field( + default=False, + metadata={"help": "compare unnormalized hypothesis scores"}, + ) + no_early_stop: bool = field( + default=False, + metadata={"help": "deprecated"}, + ) + no_beamable_mm: bool = field( + default=False, + metadata={"help": "don't use BeamableMM in attention layers"}, + ) + lenpen: float = field( + default=1, + metadata={ + "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) + unkpen: float = field( + default=0, + metadata={ + "help": "unknown word penalty: <0 produces more unks, >0 produces fewer" + }, + ) + replace_unk: Optional[str] = field( + default=None, + metadata={ + "help": "perform unknown replacement (optionally with alignment dictionary)", + "argparse_const": "@@ ", + }, + ) + sacrebleu: bool = field( + default=False, + metadata={"help": "score with sacrebleu"}, + ) + score_reference: bool = field( + default=False, + metadata={"help": "just score the reference translation"}, + ) + prefix_size: int = field( + default=0, + metadata={"help": "initialize generation by target prefix of given length"}, + ) + no_repeat_ngram_size: int = field( + default=0, + metadata={ + "help": "ngram blocking such that this size ngram cannot be repeated in the generation" + }, + ) + sampling: bool = field( + default=False, + metadata={"help": "sample hypotheses instead of using beam search"}, + ) + sampling_topk: int = field( + default=-1, + metadata={"help": "sample from top K likely next words instead of all words"}, + ) + sampling_topp: float = field( + default=-1.0, + metadata={ + "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words" + }, + ) + constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field( + default=None, + metadata={ + "help": "enables lexically constrained decoding", + "argparse_const": "ordered", + }, + ) + temperature: float = field( + default=1.0, + metadata={"help": "temperature for generation"}, + ) + diverse_beam_groups: int = field( + default=-1, + metadata={"help": "number of groups for Diverse Beam Search"}, + ) + diverse_beam_strength: float = field( + default=0.5, + metadata={"help": "strength of diversity penalty for Diverse Beam Search"}, + ) + diversity_rate: float = field( + default=-1.0, + metadata={"help": "strength of diversity penalty for Diverse Siblings Search"}, + ) + print_alignment: bool = field( + default=False, + metadata={ + "help": "if set, uses attention feedback to compute and print alignment to source tokens" + }, + ) + print_step: bool = field( + default=False, + metadata={"help": "print steps"}, + ) + lm_path: Optional[str] = field( + default=None, + metadata={"help": "path to lm checkpoint for lm fusion"}, + ) + lm_weight: float = field( + default=0.0, + metadata={"help": "weight for lm probs for lm fusion"}, + ) + + # arguments for iterative refinement generator + iter_decode_eos_penalty: float = field( + default=0.0, + metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, + ) + iter_decode_max_iter: int = field( + default=10, + metadata={"help": "maximum iterations for iterative refinement."}, + ) + iter_decode_force_max_iter: bool = field( + default=False, + metadata={ + "help": "if set, run exact the maximum number of iterations without early stop" + }, + ) + iter_decode_with_beam: int = field( + default=1, + metadata={ + "help": "if > 1, model will generate translations varying by the lengths." + }, + ) + iter_decode_with_external_reranker: bool = field( + default=False, + metadata={ + "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations" + }, + ) + retain_iter_history: bool = field( + default=False, + metadata={ + "help": "if set, decoding returns the whole history of iterative refinement" + }, + ) + retain_dropout: bool = field( + default=False, + metadata={"help": "Use dropout at inference time"}, + ) + retain_dropout_modules: Optional[List[str]] = field( + default=None, + metadata={ + "help": "if set, only retain dropout for the specified modules; " + "if not set, then dropout will be retained for all modules" + }, + ) + # special decoding format for advanced decoding. + decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( + default=None, + metadata={"help": "special decoding format for advanced decoding."}, + ) + no_seed_provided: bool = field( + default=False, + metadata={"help": "if set, dont use seed for initializing random generators"}, + ) + + +@dataclass +class CommonEvalConfig(FairseqDataclass): + path: Optional[str] = field( + default=None, + metadata={"help": "path(s) to model file(s), colon separated"}, + ) + post_process: Optional[str] = field( + default=None, + metadata={ + "help": "post-process text by removing pre-processing such as BPE, letter segmentation, etc " + "(valid options are: sentencepiece, wordpiece, letter, _EOW, none, otherwise treated as BPE symbol)", + "argparse_const": "@@ ", + "argparse_alias": "--remove-bpe", + }, + ) + quiet: bool = field(default=False, metadata={"help": "only print final scores"}) + model_overrides: str = field( + default="{}", + metadata={ + "help": "a dictionary used to override model args at generation that were used during model training" + }, + ) + results_path: Optional[str] = field( + default=None, metadata={"help": "path to save eval results (optional)"} + ) + + +@dataclass +class EvalLMConfig(FairseqDataclass): + output_word_probs: bool = field( + default=False, + metadata={ + "help": "if set, outputs words and their predicted log probabilities to standard output" + }, + ) + output_word_stats: bool = field( + default=False, + metadata={ + "help": "if set, outputs word statistics such as word count, average probability, etc" + }, + ) + context_window: int = field( + default=0, + metadata={ + "help": "ensures that every evaluated token has access to a context of at least this size, if possible" + }, + ) + softmax_batch: int = field( + default=sys.maxsize, + metadata={ + "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory" + }, + ) + + +@dataclass +class InteractiveConfig(FairseqDataclass): + buffer_size: int = field( + default=0, + metadata={ + "help": "read this many sentences into a buffer before processing them" + }, + ) + input: str = field( + default="-", + metadata={"help": "file to read from; use - for stdin"}, + ) + + +@dataclass +class FairseqConfig(object): + common: CommonConfig = CommonConfig() + common_eval: CommonEvalConfig = CommonEvalConfig() + distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() + dataset: DatasetConfig = DatasetConfig() + optimization: OptimizationConfig = OptimizationConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + bmuf: FairseqBMUFConfig = FairseqBMUFConfig() + generation: GenerationConfig = GenerationConfig() + eval_lm: EvalLMConfig = EvalLMConfig() + interactive: InteractiveConfig = InteractiveConfig() diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py new file mode 100644 index 0000000..3eb63ec --- /dev/null +++ b/fairseq/dataclass/constants.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import List + + +class StrEnum(Enum): + def __str__(self): + return self.value + + def __eq__(self, other: str): + return self.value == other + + def __repr__(self): + return self.value + + +def ChoiceEnum(choices: List[str]): + """return the Enum class used to enforce list of choices""" + return StrEnum("Choices", {k: k for k in choices}) + + +LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) +DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) +DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta"]) +DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) +GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) +GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum( + ["unigram", "ensemble", "vote", "dp", "bs"] +) +ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) +PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py new file mode 100644 index 0000000..b762af9 --- /dev/null +++ b/fairseq/dataclass/initialize.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from typing import Dict, Any + +from hydra.core.config_store import ConfigStore + +from fairseq.dataclass.configs import FairseqConfig + +from fairseq.models import MODEL_DATACLASS_REGISTRY +from fairseq.tasks import TASK_DATACLASS_REGISTRY +from fairseq.registry import REGISTRIES + + +logger = logging.getLogger(__name__) + + +def register_module_dataclass( + cs: ConfigStore, registry: Dict[str, Any], group: str +) -> None: + """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" + # note that if `group == model`, we register all model archs, not the model name. + for k, v in registry.items(): + node_ = v() + node_._name = k + cs.store(name=k, group=group, node=node_, provider="fairseq") + + +def hydra_init() -> None: + cs = ConfigStore.instance() + + for k in FairseqConfig.__dataclass_fields__: + v = FairseqConfig.__dataclass_fields__[k].default + try: + cs.store(name=k, node=v) + except BaseException: + logger.error(f"{k} - {v}") + raise + + register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") + register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") + + for k, v in REGISTRIES.items(): + register_module_dataclass(cs, v["dataclass_registry"], k) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py new file mode 100644 index 0000000..25d350c --- /dev/null +++ b/fairseq/dataclass/utils.py @@ -0,0 +1,363 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import os +from argparse import ArgumentError, ArgumentParser, Namespace +from dataclasses import _MISSING_TYPE, MISSING +from enum import Enum +from typing import Any, Dict, List, Tuple, Type + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import FairseqConfig +from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose, initialize +from omegaconf import DictConfig, OmegaConf, open_dict + + +def eval_str_list(x, x_type=float): + if x is None: + return None + if isinstance(x, str): + if len(x) == 0: + return [] + x = ast.literal_eval(x) + try: + return list(map(x_type, x)) + except TypeError: + return [x_type(x)] + + +def gen_parser_from_dataclass( + parser: ArgumentParser, + dataclass_instance: FairseqDataclass, + delete_default: bool = False, +) -> None: + """convert a dataclass instance to tailing parser arguments""" + import re + + def argparse_name(name: str): + if name == "data": + # normally data is positional args + return name + if name == "_name": + # private member, skip + return None + if name != "local_rank": + return "--" + name.replace("_", "-") + else: + return "--" + name + + def interpret_dc_type(field_type): + if isinstance(field_type, str): + raise RuntimeError("field should be a type") + typestring = str(field_type) + if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): + return field_type.__args__[0] + return field_type + + def get_kwargs_from_dc( + dataclass_instance: FairseqDataclass, k: str + ) -> Dict[str, Any]: + """k: dataclass attributes""" + + kwargs = {} + + field_type = dataclass_instance._get_type(k) + inter_type = interpret_dc_type(field_type) + + field_default = dataclass_instance._get_default(k) + + if isinstance(inter_type, type) and issubclass(inter_type, Enum): + field_choices = [t.value for t in list(inter_type)] + else: + field_choices = None + + field_help = dataclass_instance._get_help(k) + field_const = dataclass_instance._get_argparse_const(k) + + if isinstance(field_default, str) and field_default.startswith("${"): + kwargs["default"] = field_default + else: + if field_default is MISSING: + kwargs["required"] = True + if field_choices is not None: + kwargs["choices"] = field_choices + if (isinstance(inter_type, type) and issubclass(inter_type, List)) or ( + "List" in str(inter_type) + ): + if "int" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, int) + elif "float" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, float) + elif "str" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, str) + else: + raise NotImplementedError() + if field_default is not MISSING: + kwargs["default"] = ( + ",".join(map(str, field_default)) + if field_default is not None + else None + ) + elif ( + isinstance(inter_type, type) and issubclass(inter_type, Enum) + ) or "Enum" in str(inter_type): + kwargs["type"] = str + if field_default is not MISSING: + if isinstance(field_default, Enum): + kwargs["default"] = field_default.value + else: + kwargs["default"] = field_default + elif inter_type is bool: + kwargs["action"] = ( + "store_false" if field_default is True else "store_true" + ) + kwargs["default"] = field_default + else: + kwargs["type"] = inter_type + if field_default is not MISSING: + kwargs["default"] = field_default + + kwargs["help"] = field_help + if field_const is not None: + kwargs["const"] = field_const + kwargs["nargs"] = "?" + + return kwargs + + for k in dataclass_instance._get_all_attributes(): + field_name = argparse_name(dataclass_instance._get_name(k)) + if field_name is None: + continue + kwargs = get_kwargs_from_dc(dataclass_instance, k) + + field_args = [field_name] + alias = dataclass_instance._get_argparse_alias(k) + if alias is not None: + field_args.append(alias) + + if "default" in kwargs: + if isinstance(kwargs["default"], str) and kwargs["default"].startswith( + "${" + ): + continue + if delete_default: + del kwargs["default"] + try: + parser.add_argument(*field_args, **kwargs) + except ArgumentError: + pass + + +def _set_legacy_defaults(args, cls): + """Helper to set default arguments based on *add_args*.""" + if not hasattr(cls, "add_args"): + return + + import argparse + + parser = argparse.ArgumentParser( + argument_default=argparse.SUPPRESS, allow_abbrev=False + ) + cls.add_args(parser) + # copied from argparse.py: + defaults = argparse.Namespace() + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if not hasattr(defaults, action.dest): + if action.default is not argparse.SUPPRESS: + setattr(defaults, action.dest, action.default) + for key, default_value in vars(defaults).items(): + if not hasattr(args, key): + setattr(args, key, default_value) + + +def _override_attr( + sub_node: str, data_class: Type[FairseqDataclass], args: Namespace +) -> List[str]: + overrides = [] + + def get_default(f): + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + for k, v in data_class.__dataclass_fields__.items(): + if k.startswith("_"): + # private member, skip + continue + + val = get_default(v) if not hasattr(args, k) else getattr(args, k) + + if val is None: + overrides.append("{}.{}=null".format(sub_node, k)) + elif val == "": + overrides.append("{}.{}=''".format(sub_node, k)) + elif isinstance(val, str): + overrides.append("{}.{}='{}'".format(sub_node, k, val)) + else: + overrides.append("{}.{}={}".format(sub_node, k, val)) + return overrides + + +def migrate_registry( + name, value, registry, args, overrides, deletes, use_name_as_val=False +): + if value in registry: + overrides.append("{}={}".format(name, value)) + overrides.append("{}._name={}".format(name, value)) + overrides.extend(_override_attr(name, registry[value], args)) + elif use_name_as_val and value is not None: + overrides.append("{}={}".format(name, value)) + else: + deletes.append(name) + + +def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: + """use the field in args to overrides those in cfg""" + overrides = [] + deletes = [] + + for k in FairseqConfig.__dataclass_fields__.keys(): + overrides.extend( + _override_attr(k, FairseqConfig.__dataclass_fields__[k].type, args) + ) + + if args is not None: + if hasattr(args, "task"): + from fairseq.tasks import TASK_DATACLASS_REGISTRY + + migrate_registry( + "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes + ) + else: + deletes.append("task") + + # these options will be set to "None" if they have not yet been migrated + # so we can populate them with the entire flat args + CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} + + from fairseq.registry import REGISTRIES + + for k, v in REGISTRIES.items(): + if hasattr(args, k): + migrate_registry( + k, + getattr(args, k), + v["dataclass_registry"], + args, + overrides, + deletes, + use_name_as_val=k not in CORE_REGISTRIES, + ) + else: + deletes.append(k) + + no_dc = True + if hasattr(args, "arch"): + from fairseq.models import ARCH_MODEL_REGISTRY + + if args.arch in ARCH_MODEL_REGISTRY: + m_cls = ARCH_MODEL_REGISTRY[args.arch] + dc = getattr(m_cls, "__dataclass", None) + if dc is not None: + overrides.append("model={}".format(args.arch)) + overrides.append("model._name={}".format(args.arch)) + # override model params with those exist in args + overrides.extend(_override_attr("model", dc, args)) + no_dc = False + if no_dc: + deletes.append("model") + + return overrides, deletes + + +def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: + """Convert a flat argparse.Namespace to a structured DictConfig.""" + + # Here we are using field values provided in args to override counterparts inside config object + overrides, deletes = override_module_args(args) + + # configs will be in fairseq/config after installation + config_path = os.path.join("..", "config") + if not os.path.exists(config_path): + # in case of "--editable" installs we need to go one dir up + config_path = os.path.join("..", "..", "config") + + with initialize(config_path=config_path, strict=True): + composed_cfg = compose("config", overrides=overrides, strict=False) + for k in deletes: + composed_cfg[k] = None + + cfg = OmegaConf.create( + OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True) + ) + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + if cfg.task is None and getattr(args, "task", None): + cfg.task = Namespace(**vars(args)) + from fairseq.tasks import TASK_REGISTRY + + _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) + cfg.task._name = args.task + if cfg.model is None and getattr(args, "arch", None): + cfg.model = Namespace(**vars(args)) + from fairseq.models import ARCH_MODEL_REGISTRY + + _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) + cfg.model._name = args.arch + if cfg.optimizer is None and getattr(args, "optimizer", None): + cfg.optimizer = Namespace(**vars(args)) + from fairseq.optim import OPTIMIZER_REGISTRY + + _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) + cfg.optimizer._name = args.optimizer + if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): + cfg.lr_scheduler = Namespace(**vars(args)) + from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY + + _set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]) + cfg.lr_scheduler._name = args.lr_scheduler + if cfg.criterion is None and getattr(args, "criterion", None): + cfg.criterion = Namespace(**vars(args)) + from fairseq.criterions import CRITERION_REGISTRY + + _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) + cfg.criterion._name = args.criterion + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(cfg, True) + return cfg + + +def populate_dataclass( + args: Namespace, dataclass: FairseqDataclass +) -> FairseqDataclass: + for k in dataclass.__dataclass_fields__.keys(): + if k.startswith("_"): + # private member, skip + continue + if hasattr(args, k): + setattr(dataclass, k, getattr(args, k)) + + return dataclass + + +def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): + # this will be deprecated when we get rid of argparse and model_overrides logic + + with open_dict(cfg): + for k in cfg.keys(): + if isinstance(cfg[k], DictConfig): + overwrite_args_by_name(cfg[k], overrides) + elif k in overrides: + cfg[k] = overrides[k] diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py new file mode 100644 index 0000000..5cadba5 --- /dev/null +++ b/fairseq/distributed_utils.py @@ -0,0 +1,458 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import pickle +import random +import socket +import struct +import subprocess +import warnings +from argparse import Namespace +from collections import OrderedDict +from typing import Any, Dict, Mapping + +import torch +import torch.distributed as dist +from fairseq import utils, pdb +from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig +from omegaconf import open_dict + + +logger = logging.getLogger(__name__) + + +def is_master(cfg: DistributedTrainingConfig): + return cfg.distributed_rank == 0 + + +def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): + if cfg.distributed_init_method is not None or cfg.tpu: + return + + if cfg.pipeline_model_parallel: + balance_exists = ( + cfg.pipeline_balance is not None + or cfg.pipeline_encoder_balance is not None + or cfg.pipeline_decoder_balance is not None + ) + devices_exist = ( + cfg.pipeline_devices is not None + or cfg.pipeline_encoder_devices is not None + or cfg.pipeline_decoder_devices is not None + ) + if not balance_exists: + raise ValueError( + "--pipeline-balance is currently required for pipeline model parallelism" + ) + if not devices_exist: + raise ValueError( + "--pipeline-devices is currently required for pipeline model parallelism" + ) + + cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int) + if cfg.pipeline_devices is not None: + cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int) + num_pipeline_devices = len(set(cfg.pipeline_devices)) + else: + cfg.pipeline_encoder_devices = utils.eval_str_list( + cfg.pipeline_encoder_devices, type=int + ) + cfg.pipeline_decoder_devices = utils.eval_str_list( + cfg.pipeline_decoder_devices, type=int + ) + num_pipeline_devices = len( + set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices) + ) + gpus_per_node = torch.cuda.device_count() + assert ( + gpus_per_node >= num_pipeline_devices + and gpus_per_node % num_pipeline_devices == 0 + ), ( + "the number of unique device IDs in --pipeline-devices must evenly divide " + "the number of GPUs per node (multi-node pipelining is not yet supported)" + ) + num_pipelines_per_node = gpus_per_node // num_pipeline_devices + + # support torch.distributed.launch + if all( + key in os.environ + for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] + ): + cfg.distributed_init_method = "env://" #tcp://" + os.environ["MASTER_ADDR"] + ":" + os.environ['MASTER_PORT'] + cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) + cfg.distributed_rank = int(os.environ["RANK"]) + # processes are created by torch.distributed.launch + cfg.distributed_no_spawn = True + + # we can determine the init method automatically for Slurm + elif cfg.distributed_port > 0: + node_list = os.environ.get("SLURM_STEP_NODELIST") + if node_list is None: + node_list = os.environ.get("SLURM_JOB_NODELIST") + if node_list is not None: + try: + hostnames = subprocess.check_output( + ["scontrol", "show", "hostnames", node_list] + ) + cfg.distributed_init_method = "tcp://{host}:{port}".format( + host=hostnames.split()[0].decode("utf-8"), + port=cfg.distributed_port, + ) + nnodes = int(os.environ.get("SLURM_NNODES")) + ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") + if ntasks_per_node is not None: + ntasks_per_node = int(ntasks_per_node) + else: + ntasks = int(os.environ.get("SLURM_NTASKS")) + nnodes = int(os.environ.get("SLURM_NNODES")) + assert ntasks % nnodes == 0 + ntasks_per_node = int(ntasks / nnodes) + if ntasks_per_node == 1: + gpus_per_node = torch.cuda.device_count() + node_id = int(os.environ.get("SLURM_NODEID")) + cfg.distributed_rank = node_id * gpus_per_node + cfg.distributed_world_size = nnodes * gpus_per_node + elif cfg.pipeline_model_parallel: + assert ntasks_per_node == num_pipelines_per_node, ( + "SLURM --ntasks-per-node must match number of pipelines per " + "node (={})".format(num_pipelines_per_node) + ) + cfg.distributed_no_spawn = True + # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on + # the first node, [1, 2] on the second node, etc. This + # matches torch.distributed.launch. + node_id = int(os.environ.get("SLURM_NODEID")) + local_id = int(os.environ.get("SLURM_LOCALID")) + cfg.distributed_rank = node_id * num_pipelines_per_node + local_id + # In the above example, local_rank will always be in [0, 1], + # which also matches torch.distributed.launch. + cfg.local_rank = local_id + # We also want to set distributed_world_size to be the total + # number of pipelines across all nodes. + cfg.distributed_world_size = nnodes * num_pipelines_per_node + else: + assert ntasks_per_node == cfg.distributed_world_size // nnodes + cfg.distributed_no_spawn = True + cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) + cfg.local_rank = int(os.environ.get("SLURM_LOCALID")) + except subprocess.CalledProcessError as e: # scontrol failed + raise e + except FileNotFoundError: # Slurm is not installed + pass + + elif cfg.distributed_world_size > 1 or force_distributed: + # fallback for single node with multiple GPUs + assert cfg.distributed_world_size <= torch.cuda.device_count() + port = random.randint(10000, 20000) + cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) + + if cfg.pipeline_model_parallel: + if not cfg.distributed_no_spawn: + # When distributed_no_spawn is False, we expect distributed_rank and + # distributed_world_size to be based on the total number of GPUs, so + # we need to correct them to be based on the number of pipelines. + assert cfg.distributed_world_size % num_pipeline_devices == 0 + cfg.distributed_world_size = ( + cfg.distributed_world_size // num_pipeline_devices + ) + # In the case of 4-way MP on nodes with 8 GPUs, we want + # distributed_rank to be the starting GPU index for each pipeline + # i.e., 0, 2, ... + assert cfg.distributed_rank % gpus_per_node == 0 + assert cfg.distributed_rank % num_pipeline_devices == 0 + + with open_dict(cfg): + cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices + # launch one process per pipeline + cfg.distributed_num_procs = num_pipelines_per_node + + # if we have 4-way MP on a node with 8 GPUs, we want local_ranks to be 0 + # and 4, indicating the starting device IDs for each pipeline + cfg.local_rank *= num_pipeline_devices + + if cfg.local_rank > 0: + # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 + # GPU node), we need to adjust pipeline_devices accordingly + logger.debug( + "setting CUDA device={} on rank {}".format( + cfg.local_rank, cfg.distributed_rank + ) + ) + torch.cuda.set_device(cfg.local_rank) + with open_dict(cfg): + cfg.pipeline_devices = [cfg.local_rank + d for d in cfg.pipeline_devices] + logger.info( + "setting pipeline_devices={} on rank {}".format( + cfg.pipeline_devices, cfg.distributed_rank + ) + ) + elif not cfg.distributed_no_spawn: + with open_dict(cfg): + cfg.distributed_num_procs = min( + torch.cuda.device_count(), cfg.distributed_world_size + ) + + +def distributed_init(cfg: FairseqConfig): + if isinstance(cfg, Namespace): + from fairseq.dataclass.utils import convert_namespace_to_omegaconf + + cfg = convert_namespace_to_omegaconf(cfg) + + if not cfg.common.tpu: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + warnings.warn( + "Distributed is already initialized, cannot initialize twice!" + ) + else: + logger.info( + "distributed init (rank {}): {}".format( + cfg.distributed_training.distributed_rank, + cfg.distributed_training.distributed_init_method, + ) + ) + dist.init_process_group( + backend=cfg.distributed_training.distributed_backend, + init_method=cfg.distributed_training.distributed_init_method, + world_size=cfg.distributed_training.distributed_world_size, + rank=cfg.distributed_training.distributed_rank, + ) + logger.info( + "initialized host {} as rank {}".format( + socket.gethostname(), + cfg.distributed_training.distributed_rank, + ) + ) + + # perform a dummy all-reduce to initialize the NCCL communicator + if torch.cuda.is_available(): + dist.all_reduce(torch.zeros(1).cuda()) + + cfg.distributed_training.distributed_rank = torch.distributed.get_rank() + else: + import torch_xla.core.xla_model as xm + + assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size + cfg.distributed_training.local_rank = xm.get_local_ordinal() + cfg.distributed_training.distributed_rank = xm.get_ordinal() + xm.rendezvous("distributed_init") # wait for all workers + xm.mark_step() + + if is_master(cfg.distributed_training): + logging.getLogger().setLevel(logging.INFO) + else: + logging.getLogger().setLevel(logging.WARNING) + + if cfg.common.model_parallel_size > 1: + try: + from fairseq.model_parallel.megatron.mpu import ( + get_model_parallel_rank, + initialize_model_parallel, + model_parallel_cuda_manual_seed, + ) + except ImportError: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + initialize_model_parallel(cfg.common.model_parallel_size) + model_parallel_cuda_manual_seed(cfg.common.seed) + model_part_number = get_model_parallel_rank() + cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + return cfg.distributed_training.distributed_rank + + +def distributed_main(i, main, cfg: FairseqConfig, kwargs): + cfg.distributed_training.local_rank = i + if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu: + torch.cuda.set_device(cfg.distributed_training.local_rank) + if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn + cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i + + cfg.distributed_training.distributed_rank = distributed_init(cfg) + + after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) + if after_distributed_init_fn: + cfg = after_distributed_init_fn(cfg) + + main(cfg, **kwargs) + + +def call_main(cfg: FairseqConfig, main, **kwargs): + if cfg.distributed_training.distributed_init_method is None: + infer_init_method(cfg.distributed_training) + + print("local_rank : {}".format(cfg.distributed_training.local_rank)) + if cfg.distributed_training.distributed_init_method is not None: + # distributed training + if not cfg.distributed_training.distributed_no_spawn: + start_rank = cfg.distributed_training.distributed_rank + cfg.distributed_training.distributed_rank = None # assign automatically + kwargs["start_rank"] = start_rank + torch.multiprocessing.spawn( + fn=distributed_main, + args=(main, cfg, kwargs), + nprocs=min( + torch.cuda.device_count(), + cfg.distributed_training.distributed_world_size, + ), + ) + else: + distributed_main(cfg.distributed_training.local_rank, main, cfg, kwargs) + elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1: + import torch_xla.distributed.xla_multiprocessing as xmp + + torch.multiprocessing.set_sharing_strategy("file_system") + xmp.spawn( + fn=distributed_main, + args=(main, cfg, kwargs), + nprocs=8, # use all 8 TPU cores + ) + else: + # single GPU main + main(cfg, **kwargs) + + +def get_rank(): + return dist.get_rank() + + +def get_world_size(): + return dist.get_world_size() + + +def get_default_group(): + return dist.group.WORLD + + +def all_reduce(tensor, group=None): + if isinstance(group, tuple) and group[0] == "tpu": + import torch_xla.core.xla_model as xm + + return xm.all_reduce("sum", [tensor], groups=group[1]) + else: + if group is None: + group = get_default_group() + return dist.all_reduce(tensor, group=group) + + +def all_gather_list(data, group=None, max_size=16384): + """Gathers arbitrary data from all nodes into a list. + + Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python + data. Note that *data* must be picklable. + + Args: + data (Any): data from the local worker to be gathered on other workers + group (optional): group of the collective + max_size (int, optional): maximum size of the data to be gathered + across workers + """ + rank = get_rank() + world_size = get_world_size() + + buffer_size = max_size * world_size + if ( + not hasattr(all_gather_list, "_buffer") + or all_gather_list._buffer.numel() < buffer_size + ): + all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) + all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() + buffer = all_gather_list._buffer + buffer.zero_() + cpu_buffer = all_gather_list._cpu_buffer + + data = utils.move_to_cpu(data) + enc = pickle.dumps(data) + enc_size = len(enc) + header_size = 4 # size of header that contains the length of the encoded data + size = header_size + enc_size + if size > max_size: + raise ValueError( + "encoded data size ({}) exceeds max_size ({})".format(size, max_size) + ) + + header = struct.pack(">I", enc_size) + cpu_buffer[:size] = torch.ByteTensor(list(header + enc)) + start = rank * max_size + buffer[start : start + size].copy_(cpu_buffer[:size]) + + all_reduce(buffer, group=group) + + buffer = buffer.cpu() + try: + result = [] + for i in range(world_size): + out_buffer = buffer[i * max_size : (i + 1) * max_size] + (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist())) + if enc_size > 0: + result.append( + pickle.loads( + bytes(out_buffer[header_size : header_size + enc_size].tolist()) + ) + ) + return result + except pickle.UnpicklingError: + raise Exception( + "Unable to unpickle data from other workers. all_gather_list requires all " + "workers to enter the function together, so this error usually indicates " + "that the workers have fallen out of sync somehow. Workers can fall out of " + "sync if one of them runs out of memory, or if there are other conditions " + "in your training script that can cause one worker to finish an epoch " + "while other workers are still iterating over their portions of the data. " + "Try rerunning with --ddp-backend=no_c10d and see if that helps." + ) + + +def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, Any]: + """ + AllReduce a dictionary of values across workers. We separately + reduce items that are already on the device and items on CPU for + better performance. + + Args: + data (Mapping[str, Any]): dictionary of data to all-reduce, but + cannot be a nested dictionary + device (torch.device): device for the reduction + group (optional): group of the collective + """ + data_keys = list(data.keys()) + + # We want to separately reduce items that are already on the + # device and items on CPU for performance reasons. + cpu_data = OrderedDict() + device_data = OrderedDict() + for k in data_keys: + t = data[k] + if not torch.is_tensor(t): + cpu_data[k] = torch.tensor(t, dtype=torch.double) + elif t.device.type != device.type: + cpu_data[k] = t.to(dtype=torch.double) + else: + device_data[k] = t.to(dtype=torch.double) + + def _all_reduce_dict(data: OrderedDict): + if len(data) == 0: + return data + buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device) + all_reduce(buf, group=group) + split_buf = torch.split(buf, [t.numel() for t in data.values()]) + reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())] + return OrderedDict(zip(data.keys(), reduced_data)) + + cpu_data = _all_reduce_dict(cpu_data) + device_data = _all_reduce_dict(device_data) + + def get_from_stack(key): + if key in cpu_data: + return cpu_data[key] + elif key in device_data: + return device_data[key] + raise KeyError + + return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) diff --git a/fairseq/file_io.py b/fairseq/file_io.py new file mode 100644 index 0000000..d667256 --- /dev/null +++ b/fairseq/file_io.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +from typing import List, Optional + + +try: + from fvcore.common.file_io import PathManager as FVCorePathManager + +except ImportError: + FVCorePathManager = None + + +class PathManager: + """ + Wrapper for insulating OSS I/O (using Python builtin operations) from + fvcore's PathManager abstraction (for transparently handling various + internal backends). + """ + + @staticmethod + def open( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + if FVCorePathManager: + return FVCorePathManager.open( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + return open( + path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool: + if FVCorePathManager: + return FVCorePathManager.copy( + src_path=src_path, dst_path=dst_path, overwrite=overwrite + ) + return shutil.copyfile(src_path, dst_path) + + @staticmethod + def get_local_path(path: str, **kwargs) -> str: + if FVCorePathManager: + return FVCorePathManager.get_local_path(path, **kwargs) + return path + + @staticmethod + def exists(path: str) -> bool: + if FVCorePathManager: + return FVCorePathManager.exists(path) + return os.path.exists(path) + + @staticmethod + def isfile(path: str) -> bool: + if FVCorePathManager: + return FVCorePathManager.isfile(path) + return os.path.isfile(path) + + @staticmethod + def ls(path: str) -> List[str]: + if FVCorePathManager: + return FVCorePathManager.ls(path) + return os.listdir(path) + + @staticmethod + def mkdirs(path: str) -> None: + if FVCorePathManager: + return FVCorePathManager.mkdirs(path) + os.makedirs(path, exist_ok=True) + + @staticmethod + def rm(path: str) -> None: + if FVCorePathManager: + return FVCorePathManager.rm(path) + os.remove(path) + + @staticmethod + def chmod(path: str, mode: int) -> None: + if "manifold" not in path: + os.chmod(path, mode) + + @staticmethod + def register_handler(handler) -> None: + if FVCorePathManager: + return FVCorePathManager.register_handler(handler=handler) + + @staticmethod + def copy_from_local( + local_path: str, dst_path: str, overwrite: bool = False, **kwargs + ) -> None: + if FVCorePathManager: + return FVCorePathManager.copy_from_local( + local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs + ) + return shutil.copyfile(local_path, dst_path) diff --git a/fairseq/file_utils.py b/fairseq/file_utils.py new file mode 100644 index 0000000..ec6de37 --- /dev/null +++ b/fairseq/file_utils.py @@ -0,0 +1,353 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for working with the local dataset cache. +This file is adapted from `AllenNLP `_. +and `huggingface `_. +""" + +import fnmatch +import json +import logging +import os +import shutil +import tarfile +import tempfile +from functools import partial, wraps +from hashlib import sha256 +from io import open + + +try: + from torch.hub import _get_torch_home + + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv( + "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch") + ) + ) +default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq") + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + + PYTORCH_FAIRSEQ_CACHE = Path(os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)) +except (AttributeError, ImportError): + PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path) + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "pytorch_model.bin" + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def load_archive_file(archive_file): + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=None) + except EnvironmentError: + logger.info( + "Archive name '{}' was not found in archive name list. " + "We assumed '{}' was a path or URL but couldn't find any file " + "associated to this path or URL.".format( + archive_file, + archive_file, + ) + ) + return None + + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info( + "loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file + ) + ) + + # Extract archive to temp dir and replace .tar.bz2 if necessary + tempdir = None + if not os.path.isdir(resolved_archive_file): + tempdir = tempfile.mkdtemp() + logger.info( + "extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir + ) + ) + ext = os.path.splitext(archive_file)[1][1:] + with tarfile.open(resolved_archive_file, "r:" + ext) as archive: + top_dir = os.path.commonprefix(archive.getnames()) + archive.extractall(tempdir) + os.remove(resolved_archive_file) + shutil.move(os.path.join(tempdir, top_dir), resolved_archive_file) + shutil.rmtree(tempdir) + + return resolved_archive_file + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the URL's, delimited + by a period. + """ + url_bytes = url.encode("utf-8") + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + etag_hash = sha256(etag_bytes) + filename += "." + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + ".json" + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata["url"] + etag = metadata["etag"] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ("http", "https", "s3"): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == "": + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError( + "unable to parse {} as a URL or as a local path".format(url_or_filename) + ) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + from botocore.exceptions import ClientError + + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + import boto3 + + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + import boto3 + + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def request_wrap_timeout(func, url): + import requests + + for attempt, timeout in enumerate([10, 20, 40, 60, 60]): + try: + return func(timeout=timeout) + except requests.exceptions.Timeout as e: + logger.warning( + "Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs", + url, + attempt, + timeout, + exc_info=e, + ) + continue + raise RuntimeError(f"Unable to fetch file {url}") + + +def http_get(url, temp_file): + import requests + from tqdm import tqdm + + req = request_wrap_timeout(partial(requests.get, url, stream=True), url) + content_length = req.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + try: + import requests + + response = request_wrap_timeout( + partial(requests.head, url, allow_redirects=True), url + ) + if response.status_code != 200: + etag = None + else: + etag = response.headers.get("ETag") + except RuntimeError: + etag = None + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # If we don't have a connection (etag is None) and can't identify the file + # try to get the last downloaded one + if not os.path.exists(cache_path) and etag is None: + matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*") + matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files)) + if matching_files: + cache_path = os.path.join(cache_dir, matching_files[-1]) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, "wb") as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + output_string = json.dumps(meta) + meta_file.write(output_string) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + """ + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + """ + collection = set() + with open(filename, "r", encoding="utf-8") as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py new file mode 100644 index 0000000..3be7078 --- /dev/null +++ b/fairseq/hub_utils.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import copy +import logging +import os +from typing import Any, Dict, Iterator, List + +import torch +from fairseq import utils +from fairseq.data import encoders +from omegaconf import open_dict +from torch import nn + + +logger = logging.getLogger(__name__) + + +def from_pretrained( + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + archive_map=None, + **kwargs +): + from fairseq import checkpoint_utils, file_utils + + if archive_map is not None: + if model_name_or_path in archive_map: + model_name_or_path = archive_map[model_name_or_path] + if data_name_or_path is not None and data_name_or_path in archive_map: + data_name_or_path = archive_map[data_name_or_path] + + # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe) + # for each model + if isinstance(model_name_or_path, dict): + for k, v in model_name_or_path.items(): + if k == "checkpoint_file": + checkpoint_file = v + elif ( + k != "path" + # only set kwargs that don't already have overrides + and k not in kwargs + ): + kwargs[k] = v + model_name_or_path = model_name_or_path["path"] + + model_path = file_utils.load_archive_file(model_name_or_path) + + # convenience hack for loading data and BPE codes from model archive + if data_name_or_path.startswith("."): + kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path)) + else: + kwargs["data"] = file_utils.load_archive_file(data_name_or_path) + for file, arg in { + "code": "bpe_codes", + "bpecodes": "bpe_codes", + "sentencepiece.bpe.model": "sentencepiece_model", + }.items(): + path = os.path.join(model_path, file) + if os.path.exists(path): + kwargs[arg] = path + + if "user_dir" in kwargs: + utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"])) + + models, args, task = checkpoint_utils.load_model_ensemble_and_task( + [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)], + arg_overrides=kwargs, + ) + + return { + "args": args, + "task": task, + "models": models, + } + + +class GeneratorHubInterface(nn.Module): + """ + PyTorch Hub interface for generating sequences from a pre-trained + translation or language model. + """ + + def __init__(self, cfg, task, models): + super().__init__() + self.cfg = cfg + self.task = task + self.models = nn.ModuleList(models) + self.src_dict = task.source_dictionary + self.tgt_dict = task.target_dictionary + + # optimize model for generation + for model in self.models: + model.prepare_for_inference_(cfg) + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + self.align_dict = utils.load_align_dict(cfg.generation.replace_unk) + + self.tokenizer = encoders.build_tokenizer(cfg.tokenizer) + self.bpe = encoders.build_bpe(cfg.bpe) + + self.max_positions = utils.resolve_max_positions( + self.task.max_positions(), *[model.max_positions() for model in models] + ) + + # this is useful for determining the device + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) + + @property + def device(self): + return self._float_tensor.device + + def translate( + self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs + ) -> List[str]: + return self.sample(sentences, beam, verbose, **kwargs) + + def sample( + self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs + ) -> List[str]: + if isinstance(sentences, str): + return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0] + tokenized_sentences = [self.encode(sentence) for sentence in sentences] + batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs) + return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos] + + def score(self, sentences: List[str], **kwargs): + if isinstance(sentences, str): + return self.score([sentences], **kwargs)[0] + # NOTE: this doesn't support translation tasks currently + tokenized_sentences = [self.encode(sentence) for sentence in sentences] + return [ + hypos[0] + for hypos in self.generate( + tokenized_sentences, score_reference=True, **kwargs + ) + ] + + def generate( + self, + tokenized_sentences: List[torch.LongTensor], + beam: int = 5, + verbose: bool = False, + skip_invalid_size_inputs=False, + inference_step_args=None, + **kwargs + ) -> List[List[Dict[str, torch.Tensor]]]: + if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1: + return self.generate( + tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs + )[0] + + # build generator using current args as well as any kwargs + gen_args = copy.copy(self.cfg) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) + generator = self.task.build_generator(self.models, gen_args) + + inference_step_args = inference_step_args or {} + results = [] + for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + batch = utils.apply_to_sample(lambda t: t.to(self.device), batch) + translations = self.task.inference_step( + generator, self.models, batch, **inference_step_args + ) + for id, hypos in zip(batch["id"].tolist(), translations): + results.append((id, hypos)) + + # sort output to match input order + outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])] + + if verbose: + + def getarg(name, default): + return getattr(gen_args, name, getattr(self.args, name, default)) + + for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs): + src_str_with_unk = self.string(source_tokens) + logger.info("S\t{}".format(src_str_with_unk)) + for hypo in target_hypotheses: + hypo_str = self.decode(hypo["tokens"]) + logger.info("H\t{}\t{}".format(hypo["score"], hypo_str)) + logger.info( + "P\t{}".format( + " ".join( + map( + lambda x: "{:.4f}".format(x), + hypo["positional_scores"].tolist(), + ) + ) + ) + ) + if hypo["alignment"] is not None and getarg( + "print_alignment", False + ): + logger.info( + "A\t{}".format( + " ".join( + [ + "{}-{}".format(src_idx, tgt_idx) + for src_idx, tgt_idx in hypo["alignment"] + ] + ) + ) + ) + return outputs + + def encode(self, sentence: str) -> torch.LongTensor: + sentence = self.tokenize(sentence) + sentence = self.apply_bpe(sentence) + return self.binarize(sentence) + + def decode(self, tokens: torch.LongTensor) -> str: + sentence = self.string(tokens) + sentence = self.remove_bpe(sentence) + return self.detokenize(sentence) + + def tokenize(self, sentence: str) -> str: + if self.tokenizer is not None: + sentence = self.tokenizer.encode(sentence) + return sentence + + def detokenize(self, sentence: str) -> str: + if self.tokenizer is not None: + sentence = self.tokenizer.decode(sentence) + return sentence + + def apply_bpe(self, sentence: str) -> str: + if self.bpe is not None: + sentence = self.bpe.encode(sentence) + return sentence + + def remove_bpe(self, sentence: str) -> str: + if self.bpe is not None: + sentence = self.bpe.decode(sentence) + return sentence + + def binarize(self, sentence: str) -> torch.LongTensor: + return self.src_dict.encode_line(sentence, add_if_not_exist=False).long() + + def string(self, tokens: torch.LongTensor) -> str: + return self.tgt_dict.string(tokens) + + def _build_batches( + self, tokens: List[List[int]], skip_invalid_size_inputs: bool + ) -> Iterator[Dict[str, Any]]: + lengths = torch.LongTensor([t.numel() for t in tokens]) + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.build_dataset_for_inference(tokens, lengths), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=self.max_positions, + ignore_invalid_inputs=skip_invalid_size_inputs, + disable_iterator_cache=True, + ).next_epoch_itr(shuffle=False) + return batch_iterator + + +class BPEHubInterface(object): + """PyTorch Hub interface for Byte-Pair Encoding (BPE).""" + + def __init__(self, bpe, **kwargs): + super().__init__() + args = argparse.Namespace(bpe=bpe, **kwargs) + self.bpe = encoders.build_bpe(args) + assert self.bpe is not None + + def encode(self, sentence: str) -> str: + return self.bpe.encode(sentence) + + def decode(self, sentence: str) -> str: + return self.bpe.decode(sentence) + + +class TokenizerHubInterface(object): + """PyTorch Hub interface for tokenization.""" + + def __init__(self, tokenizer, **kwargs): + super().__init__() + args = argparse.Namespace(tokenizer=tokenizer, **kwargs) + self.tokenizer = encoders.build_tokenizer(args) + assert self.tokenizer is not None + + def encode(self, sentence: str) -> str: + return self.tokenizer.encode(sentence) + + def decode(self, sentence: str) -> str: + return self.tokenizer.decode(sentence) diff --git a/fairseq/incremental_decoding_utils.py b/fairseq/incremental_decoding_utils.py new file mode 100644 index 0000000..b26e6cd --- /dev/null +++ b/fairseq/incremental_decoding_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import uuid +from typing import Dict, Optional + +from torch import Tensor + + +class FairseqIncrementalState(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return "{}.{}".format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + ) + return cls diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py new file mode 100644 index 0000000..4fb0946 --- /dev/null +++ b/fairseq/iterative_refinement_generator.py @@ -0,0 +1,359 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import namedtuple + +import numpy as np +import torch +from fairseq import utils + + +DecoderOut = namedtuple( + "IterativeRefinementDecoderOut", + ["output_tokens", "output_scores", "attn", "step", "max_step", "history"], +) + + +class IterativeRefinementGenerator(object): + def __init__( + self, + tgt_dict, + models=None, + eos_penalty=0.0, + max_iter=10, + max_ratio=2, + beam_size=1, + decoding_format=None, + retain_dropout=False, + adaptive=True, + retain_history=False, + reranking=False, + ): + """ + Generates translations based on iterative refinement. + + Args: + tgt_dict: target dictionary + eos_penalty: if > 0.0, it penalized early-stopping in decoding + max_iter: maximum number of refinement iterations + max_ratio: generate sequences of maximum length ax, where x is the source length + decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'} + retain_dropout: retaining dropout in the inference + adaptive: decoding with early stop + """ + self.bos = tgt_dict.bos() + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.eos_penalty = eos_penalty + self.max_iter = max_iter + self.max_ratio = max_ratio + self.beam_size = beam_size + self.reranking = reranking + self.decoding_format = decoding_format + self.retain_dropout = retain_dropout + self.retain_history = retain_history + self.adaptive = adaptive + self.models = models + + def generate_batched_itr( + self, + data_itr, + maxlen_a=None, + maxlen_b=None, + cuda=False, + timer=None, + prefix_size=0, + ): + """Iterate over a batched dataset and yield individual translations. + + Args: + maxlen_a/b: generate sequences of maximum length ax + b, + where x is the source sentence length. + cuda: use GPU for generation + timer: StopwatchMeter for timing generations. + """ + + for sample in data_itr: + if "net_input" not in sample: + continue + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate( + self.models, + sample, + prefix_tokens=sample["target"][:, :prefix_size] + if prefix_size > 0 + else None, + ) + if timer is not None: + timer.stop(sample["ntokens"]) + for i, id in enumerate(sample["id"]): + # remove padding + src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad) + ref = utils.strip_pad(sample["target"][i, :], self.pad) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate(self, models, sample, prefix_tokens=None, constraints=None): + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the IterativeRefinementGenerator is not supported" + ) + + # TODO: iterative refinement generator does not support ensemble for now. + if not self.retain_dropout: + for model in models: + model.eval() + + model, reranker = models[0], None + if self.reranking: + assert len(models) > 1, "Assuming the last checkpoint is the reranker" + assert ( + self.beam_size > 1 + ), "Reranking requires multiple translation for each example" + + reranker = models[-1] + models = models[:-1] + + if len(models) > 1 and hasattr(model, "enable_ensemble"): + assert model.allow_ensemble, "{} does not support ensembling".format( + model.__class__.__name__ + ) + model.enable_ensemble(models) + + # TODO: better encoder inputs? + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size() + + # initialize + encoder_out = model.forward_encoder([src_tokens, src_lengths]) + prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens) + + if self.beam_size > 1: + assert ( + model.allow_length_beam + ), "{} does not support decoding with length beam.".format( + model.__class__.__name__ + ) + + # regenerate data based on length-beam + length_beam_order = ( + utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1) + ) + encoder_out = model.encoder.reorder_encoder_out( + encoder_out, length_beam_order + ) + prev_decoder_out = model.regenerate_length_beam( + prev_decoder_out, self.beam_size + ) + bsz = bsz * self.beam_size + + sent_idxs = torch.arange(bsz) + prev_output_tokens = prev_decoder_out.output_tokens.clone() + + if self.retain_history: + prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens]) + + finalized = [[] for _ in range(bsz)] + + def is_a_loop(x, y, s, a): + b, l_x, l_y = x.size(0), x.size(1), y.size(1) + if l_x > l_y: + y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1) + s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1) + if a is not None: + a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1) + elif l_x < l_y: + x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1) + return (x == y).all(1), y, s, a + + def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): + cutoff = prev_out_token.ne(self.pad) + tokens = prev_out_token[cutoff] + if prev_out_score is None: + scores, score = None, None + else: + scores = prev_out_score[cutoff] + score = scores.mean() + + if prev_out_attn is None: + hypo_attn, alignment = None, None + else: + hypo_attn = prev_out_attn[cutoff] + alignment = hypo_attn.max(dim=1)[1] + return { + "steps": step, + "tokens": tokens, + "positional_scores": scores, + "score": score, + "hypo_attn": hypo_attn, + "alignment": alignment, + } + + for step in range(self.max_iter + 1): + + decoder_options = { + "eos_penalty": self.eos_penalty, + "max_ratio": self.max_ratio, + "decoding_format": self.decoding_format, + } + prev_decoder_out = prev_decoder_out._replace( + step=step, + max_step=self.max_iter + 1, + ) + + decoder_out = model.forward_decoder( + prev_decoder_out, encoder_out, **decoder_options + ) + + if self.adaptive: + # terminate if there is a loop + terminated, out_tokens, out_scores, out_attn = is_a_loop( + prev_output_tokens, + decoder_out.output_tokens, + decoder_out.output_scores, + decoder_out.attn, + ) + decoder_out = decoder_out._replace( + output_tokens=out_tokens, + output_scores=out_scores, + attn=out_attn, + ) + + else: + terminated = decoder_out.output_tokens.new_zeros( + decoder_out.output_tokens.size(0) + ).bool() + + if step == self.max_iter: # reach last iteration, terminate + terminated.fill_(1) + + # collect finalized sentences + finalized_idxs = sent_idxs[terminated] + finalized_tokens = decoder_out.output_tokens[terminated] + finalized_scores = decoder_out.output_scores[terminated] + finalized_attn = ( + None + if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) + else decoder_out.attn[terminated] + ) + + if self.retain_history: + finalized_history_tokens = [h[terminated] for h in decoder_out.history] + + for i in range(finalized_idxs.size(0)): + finalized[finalized_idxs[i]] = [ + finalized_hypos( + step, + finalized_tokens[i], + finalized_scores[i], + None if finalized_attn is None else finalized_attn[i], + ) + ] + + if self.retain_history: + finalized[finalized_idxs[i]][0]["history"] = [] + for j in range(len(finalized_history_tokens)): + finalized[finalized_idxs[i]][0]["history"].append( + finalized_hypos( + step, finalized_history_tokens[j][i], None, None + ) + ) + + # check if all terminated + if terminated.sum() == terminated.size(0): + break + + # for next step + not_terminated = ~terminated + prev_decoder_out = decoder_out._replace( + output_tokens=decoder_out.output_tokens[not_terminated], + output_scores=decoder_out.output_scores[not_terminated], + attn=decoder_out.attn[not_terminated] + if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0) + else None, + history=[h[not_terminated] for h in decoder_out.history] + if decoder_out.history is not None + else None, + ) + encoder_out = model.encoder.reorder_encoder_out( + encoder_out, not_terminated.nonzero(as_tuple=False).squeeze() + ) + sent_idxs = sent_idxs[not_terminated] + prev_output_tokens = prev_decoder_out.output_tokens.clone() + + if self.beam_size > 1: + if reranker is not None: + finalized = self.rerank( + reranker, finalized, [src_tokens, src_lengths], self.beam_size + ) + + # aggregate information from length beam + finalized = [ + finalized[ + np.argmax( + [ + finalized[self.beam_size * i + j][0]["score"] + for j in range(self.beam_size) + ] + ) + + self.beam_size * i + ] + for i in range(len(finalized) // self.beam_size) + ] + + return finalized + + def rerank(self, reranker, finalized, encoder_input, beam_size): + def rebuild_batch(finalized): + finalized_tokens = [f[0]["tokens"] for f in finalized] + finalized_maxlen = max(f.size(0) for f in finalized_tokens) + final_output_tokens = ( + finalized_tokens[0] + .new_zeros(len(finalized_tokens), finalized_maxlen) + .fill_(self.pad) + ) + for i, f in enumerate(finalized_tokens): + final_output_tokens[i, : f.size(0)] = f + return final_output_tokens + + final_output_tokens = rebuild_batch(finalized) + final_output_tokens[ + :, 0 + ] = self.eos # autoregressive model assumes starting with EOS + + reranker_encoder_out = reranker.encoder(*encoder_input) + length_beam_order = ( + utils.new_arange( + final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1) + ) + .t() + .reshape(-1) + ) + reranker_encoder_out = reranker.encoder.reorder_encoder_out( + reranker_encoder_out, length_beam_order + ) + reranking_scores = reranker.get_normalized_probs( + reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out), + True, + None, + ) + reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None]) + reranking_masks = final_output_tokens[:, 1:].ne(self.pad) + reranking_scores = ( + reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1) + ) + reranking_scores = reranking_scores / reranking_masks.sum(1).type_as( + reranking_scores + ) + + for i in range(len(finalized)): + finalized[i][0]["score"] = reranking_scores[i] + + return finalized diff --git a/fairseq/legacy_distributed_data_parallel.py b/fairseq/legacy_distributed_data_parallel.py new file mode 100644 index 0000000..44f87c7 --- /dev/null +++ b/fairseq/legacy_distributed_data_parallel.py @@ -0,0 +1,171 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +A modified version of the legacy DistributedDataParallel module that uses c10d +communication primitives. This version is simpler than the latest PyTorch +version and is useful for debugging. Notably it does not overlap gradient +communication with the backward pass, which makes it slower but more robust +than the PyTorch version. + +This version also supports the *no_sync* context manager, which allows faster +training with `--update-freq`. +""" + +import copy +from collections import OrderedDict +from contextlib import contextmanager + +import torch +from torch import nn +from torch.autograd import Variable + +from . import distributed_utils + + +class LegacyDistributedDataParallel(nn.Module): + """Implements distributed data parallelism at the module level. + + A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`. + This version uses a c10d process group for communication and does not + broadcast buffers. + + Args: + module (~torch.nn.Module): module to be parallelized + world_size (int): number of parallel workers + process_group (optional): the c10d process group to be used for + distributed data all-reduction. If None, the default process group + will be used. + buffer_size (int, optional): number of elements to buffer before + performing all-reduce (default: 256M). + """ + + def __init__(self, module, world_size, process_group=None, buffer_size=2 ** 28): + super().__init__() + + self.module = module + self.world_size = world_size + self.process_group = process_group + + # Never use a bigger buffer than the number of model params + self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) + self.buffer = None + + # We can also forcibly accumulate grads locally and only do the + # all-reduce at some later time + self.accumulate_grads = False + + # make per-device lists of parameters + paramlists = OrderedDict() + for param in self.module.parameters(): + device = param.device + if paramlists.get(device) is None: + paramlists[device] = [] + paramlists[device] += [param] + self.per_device_params = list(paramlists.values()) + + def __getstate__(self): + attrs = copy.copy(self.__dict__) + return attrs + + def __setstate__(self, state): + super().__setstate__(state) + + @contextmanager + def no_sync(self): + """A context manager to disable gradient synchronization.""" + old_accumulate_grads = self.accumulate_grads + self.accumulate_grads = True + yield + self.accumulate_grads = old_accumulate_grads + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def all_reduce(self): + """ + This function must be called explicitly after backward to reduce + gradients. There is no automatic hook like c10d. + """ + + def all_reduce_params(params): + buffer = self.buffer + nonzero_buffer = False + if len(params) > 1: + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) + nonzero_buffer = True + else: + buffer[offset : offset + sz].zero_() + offset += sz + else: + # we only have a single grad to all-reduce + p = params[0] + if p.grad is not None: + buffer = p.grad.data + nonzero_buffer = True + elif p.numel() <= self.buffer.numel(): + buffer = buffer[: p.numel()] + buffer.zero_() + else: + buffer = torch.zeros_like(p) + + if nonzero_buffer: + buffer.div_(self.world_size) + + distributed_utils.all_reduce(buffer, self.process_group) + + # copy all-reduced grads back into their original place + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) + else: + p.grad = buffer[offset : offset + sz].view_as(p).clone() + offset += sz + + def reduction_fn(): + # This function only needs to be called once + if self.accumulate_grads: + return + + if self.buffer is None: + self.buffer = next(self.module.parameters()).new(self.buffer_size) + + for params in self.per_device_params: + # All-reduce the gradients in buckets + offset = 0 + buffered_params = [] + for param in params: + if not param.requires_grad: + continue + if param.grad is None: + param.grad = torch.zeros_like(param) + if param.grad.requires_grad: + raise RuntimeError( + "DistributedDataParallel only works " + "with gradients that don't require " + "grad" + ) + sz = param.numel() + if sz > self.buffer.numel(): + # all-reduce big params directly + all_reduce_params([param]) + else: + if offset + sz > self.buffer.numel(): + all_reduce_params(buffered_params) + offset = 0 + buffered_params.clear() + buffered_params.append(param) + offset += sz + + if len(buffered_params) > 0: + all_reduce_params(buffered_params) + + reduction_fn() diff --git a/fairseq/logging/__init__.py b/fairseq/logging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py new file mode 100644 index 0000000..6793ef5 --- /dev/null +++ b/fairseq/logging/meters.py @@ -0,0 +1,291 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import bisect +import time +from collections import OrderedDict +from typing import Dict, Optional + + +try: + import torch + + def type_as(a, b): + if torch.is_tensor(a) and torch.is_tensor(b): + return a.to(b) + else: + return a + + +except ImportError: + torch = None + + def type_as(a, b): + return a + + +try: + import numpy as np +except ImportError: + np = None + + +class Meter(object): + """Base class for Meters.""" + + def __init__(self): + pass + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + pass + + def reset(self): + raise NotImplementedError + + @property + def smoothed_value(self) -> float: + """Smoothed value used for logging.""" + raise NotImplementedError + + +def safe_round(number, ndigits): + if hasattr(number, "__round__"): + return round(number, ndigits) + elif torch is not None and torch.is_tensor(number) and number.numel() == 1: + return safe_round(number.item(), ndigits) + elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): + return safe_round(number.item(), ndigits) + else: + return number + + +class AverageMeter(Meter): + """Computes and stores the average and current value""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.val = None # most recent update + self.sum = 0 # sum from all updates + self.count = 0 # total n from all updates + + def update(self, val, n=1): + if val is not None: + self.val = val + if n > 0: + self.sum = type_as(self.sum, val) + (val * n) + self.count = type_as(self.count, n) + n + + def state_dict(self): + return { + "val": self.val, + "sum": self.sum, + "count": self.count, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.val = state_dict["val"] + self.sum = state_dict["sum"] + self.count = state_dict["count"] + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.count if self.count > 0 else self.val + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class TimeMeter(Meter): + """Computes the average occurrence of some event per second""" + + def __init__( + self, + init: int = 0, + n: int = 0, + round: Optional[int] = None, + ): + self.round = round + self.reset(init, n) + + def reset(self, init=0, n=0): + self.init = init + self.start = time.perf_counter() + self.n = n + self.i = 0 + + def update(self, val=1): + self.n = type_as(self.n, val) + val + self.i += 1 + + def state_dict(self): + return { + "init": self.elapsed_time, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + if "start" in state_dict: + # backwards compatibility for old state_dicts + self.reset(init=state_dict["init"]) + else: + self.reset(init=state_dict["init"], n=state_dict["n"]) + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.n / self.elapsed_time + + @property + def elapsed_time(self): + return self.init + (time.perf_counter() - self.start) + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class StopwatchMeter(Meter): + """Computes the sum/avg duration of some event in seconds""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.sum = 0 + self.n = 0 + self.start_time = None + + def start(self): + self.start_time = time.perf_counter() + + def stop(self, n=1, prehook=None): + if self.start_time is not None: + if prehook is not None: + prehook() + delta = time.perf_counter() - self.start_time + self.sum = self.sum + delta + self.n = type_as(self.n, n) + n + + def reset(self): + self.sum = 0 # cumulative time during which stopwatch was active + self.n = 0 # total n across all start/stop + self.start() + + def state_dict(self): + return { + "sum": self.sum, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.sum = state_dict["sum"] + self.n = state_dict["n"] + self.start_time = None + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.n if self.n > 0 else self.sum + + @property + def elapsed_time(self): + if self.start_time is None: + return 0.0 + return time.perf_counter() - self.start_time + + @property + def smoothed_value(self) -> float: + val = self.avg if self.sum > 0 else self.elapsed_time + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class MetersDict(OrderedDict): + """A sorted dictionary of :class:`Meters`. + + Meters are sorted according to a priority that is given when the + meter is first added to the dictionary. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.priorities = [] + + def __setitem__(self, key, value): + assert key not in self, "MetersDict doesn't support reassignment" + priority, value = value + bisect.insort(self.priorities, (priority, len(self.priorities), key)) + super().__setitem__(key, value) + for _, _, key in self.priorities: # reorder dict to match priorities + self.move_to_end(key) + + def add_meter(self, key, meter, priority): + self.__setitem__(key, (priority, meter)) + + def state_dict(self): + return [ + (pri, key, self[key].__class__.__name__, self[key].state_dict()) + for pri, _, key in self.priorities + # can't serialize DerivedMeter instances + if not isinstance(self[key], MetersDict._DerivedMeter) + ] + + def load_state_dict(self, state_dict): + self.clear() + self.priorities.clear() + for pri, key, meter_cls, meter_state in state_dict: + meter = globals()[meter_cls]() + meter.load_state_dict(meter_state) + self.add_meter(key, meter, pri) + + def get_smoothed_value(self, key: str) -> float: + """Get a single smoothed value.""" + meter = self[key] + if isinstance(meter, MetersDict._DerivedMeter): + return meter.fn(self) + else: + return meter.smoothed_value + + def get_smoothed_values(self) -> Dict[str, float]: + """Get all smoothed values.""" + return OrderedDict( + [ + (key, self.get_smoothed_value(key)) + for key in self.keys() + if not key.startswith("_") + ] + ) + + def reset(self): + """Reset Meter instances.""" + for meter in self.values(): + if isinstance(meter, MetersDict._DerivedMeter): + continue + meter.reset() + + class _DerivedMeter(Meter): + """A Meter whose values are derived from other Meters.""" + + def __init__(self, fn): + self.fn = fn + + def reset(self): + pass diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py new file mode 100644 index 0000000..7b56e31 --- /dev/null +++ b/fairseq/logging/metrics.py @@ -0,0 +1,288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +A standalone module for aggregating metrics. + +Metrics can be logged from anywhere using the `log_*` functions defined +in this module. The logged values will be aggregated dynamically based +on the aggregation context in which the logging occurs. See the +:func:`aggregate` context manager for more details. +""" + +import contextlib +import time +import uuid +from collections import OrderedDict, defaultdict +from typing import Callable, Dict, List, Optional + +from .meters import * + + +# Aggregation contexts are considered "active" when inside the scope +# created by the :func:`aggregate` context manager. +_aggregators = OrderedDict() +_active_aggregators = OrderedDict() +_active_aggregators_cnt = defaultdict(lambda: 0) + + +def reset() -> None: + """Reset all metrics aggregators.""" + _aggregators.clear() + _active_aggregators.clear() + _active_aggregators_cnt.clear() + + # The "default" aggregator observes all logged values. + _aggregators["default"] = MetersDict() + _active_aggregators["default"] = _aggregators["default"] + _active_aggregators_cnt["default"] = 1 + + +reset() + + +@contextlib.contextmanager +def aggregate(name: Optional[str] = None, new_root: bool = False): + """Context manager to aggregate metrics under a given name. + + Aggregations can be nested. If *new_root* is ``False``, then logged + metrics will be recorded along the entire stack of nested + aggregators, including a global "default" aggregator. If *new_root* + is ``True``, then this aggregator will be the root of a new + aggregation stack, thus bypassing any parent aggregators. + + Note that aggregation contexts are uniquely identified by their + *name* (e.g., train, valid). Creating a context with an existing + name will reuse the corresponding :class:`MetersDict` instance. + If no name is given, then a temporary aggregator will be created. + + Usage:: + + with metrics.aggregate("train"): + for step, batch in enumerate(epoch): + with metrics.aggregate("train_inner") as agg: + metrics.log_scalar("loss", get_loss(batch)) + if step % log_interval == 0: + print(agg.get_smoothed_value("loss")) + agg.reset() + print(metrics.get_smoothed_values("train")["loss"]) + + Args: + name (str): name of the aggregation. Defaults to a + random/temporary name if not given explicitly. + new_root (bool): make this aggregation the root of a new + aggregation stack. + """ + if name is None: + # generate a temporary name + name = str(uuid.uuid4()) + assert name not in _aggregators + agg = MetersDict() + else: + assert name != "default" + agg = _aggregators.setdefault(name, MetersDict()) + + if new_root: + backup_aggregators = _active_aggregators.copy() + _active_aggregators.clear() + backup_aggregators_cnt = _active_aggregators_cnt.copy() + _active_aggregators_cnt.clear() + + _active_aggregators[name] = agg + _active_aggregators_cnt[name] += 1 + + yield agg + + _active_aggregators_cnt[name] -= 1 + if _active_aggregators_cnt[name] == 0 and name in _active_aggregators: + del _active_aggregators[name] + + if new_root: + _active_aggregators.clear() + _active_aggregators.update(backup_aggregators) + _active_aggregators_cnt.clear() + _active_aggregators_cnt.update(backup_aggregators_cnt) + + +def get_active_aggregators() -> List[MetersDict]: + return list(_active_aggregators.values()) + + +def log_scalar( + key: str, + value: float, + weight: float = 1, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value. + + Args: + key (str): name of the field to log + value (float): value to log + weight (float): weight that this value contributes to the average. + A weight of 0 will always log the latest value. + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, AverageMeter(round=round), priority) + agg[key].update(value, weight) + + +def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): + """Log a scalar value derived from other meters. + + Args: + key (str): name of the field to log + fn (Callable[[MetersDict], float]): function that takes a single + argument *meters* and returns the derived value + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, MetersDict._DerivedMeter(fn), priority) + + +def log_speed( + key: str, + value: float, + priority: int = 30, + round: Optional[int] = None, +): + """Log the rate of some quantity per second. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, TimeMeter(round=round), priority) + agg[key].reset() # reset meter on the first call + else: + agg[key].update(value) + + +def log_start_time(key: str, priority: int = 40, round: Optional[int] = None): + """Log the duration of some event in seconds. + + The duration will be computed once :func:`log_stop_time` is called. + + Args: + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, StopwatchMeter(round=round), priority) + agg[key].start() + + +def log_stop_time(key: str, weight: float = 0.0, prehook=None): + """Log the duration of some event in seconds. + + The duration will be computed since :func:`log_start_time` was called. + Set weight > 0 to report the average time instead of the sum. + + Args: + key (str): name of the field to log + weight (float): weight that this time contributes to the average + prehook (function, no arguments): will be called before the timer + is stopped. For example, use prehook=torch.cuda.synchronize to + make sure all gpu operations are done before timer is stopped. + """ + for agg in get_active_aggregators(): + if key in agg: + agg[key].stop(weight, prehook) + + +def log_custom( + new_meter_fn: Callable[[], Meter], + key: str, + *args, + priority: int = 50, + **kwargs, +): + """Log using a custom Meter. + + Any extra *args* or *kwargs* will be passed through to the Meter's + *update* method. + + Args: + new_meter_fn (Callable[[], Meter]): function that returns a new + Meter instance + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, new_meter_fn(), priority) + agg[key].update(*args, **kwargs) + + +def reset_meter(name: str, key: str) -> None: + """Reset Meter instance aggregated under a given *name* and *key*.""" + meter = get_meter(name, key) + if meter is not None: + meter.reset() + + +def reset_meters(name: str) -> None: + """Reset Meter instances aggregated under a given *name*.""" + meters = get_meters(name) + if meters is not None: + meters.reset() + + +def get_meter(name: str, key: str) -> Meter: + """Get a single Meter instance aggregated under *name* and *key*. + + Returns: + Meter or None if no metrics have been logged under *name* and *key*. + """ + if name not in _aggregators: + return None + return _aggregators[name].get(key, None) + + +def get_meters(name: str) -> MetersDict: + """Get Meter instances aggregated under a given *name*. + + Returns: + MetersDict or None if no metrics have been logged under *name*. + """ + return _aggregators.get(name, None) + + +def get_smoothed_value(name: str, key: str) -> float: + """Get a single smoothed value. + + Raises: + KeyError: if no metrics have been logged under *name* and *key*. + """ + return _aggregators[name].get_smoothed_value(key) + + +def get_smoothed_values(name: str) -> Dict[str, float]: + """Get smoothed values aggregated under a given *name*. + + Raises: + KeyError: if no metrics have been logged under *name*. + """ + return _aggregators[name].get_smoothed_values() + + +def state_dict(): + return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()]) + + +def load_state_dict(state_dict): + for name, agg_state in state_dict.items(): + _aggregators[name] = MetersDict() + _aggregators[name].load_state_dict(agg_state) diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py new file mode 100644 index 0000000..63e5394 --- /dev/null +++ b/fairseq/logging/progress_bar.py @@ -0,0 +1,355 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Wrapper around various loggers and progress bars (e.g., tqdm). +""" + +import atexit +import json +import logging +import os +import sys +from collections import OrderedDict +from contextlib import contextmanager +from numbers import Number +from typing import Optional + +import torch + +from .meters import AverageMeter, StopwatchMeter, TimeMeter + + +logger = logging.getLogger(__name__) + + +def progress_bar( + iterator, + log_format: Optional[str] = None, + log_interval: int = 100, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + tensorboard_logdir: Optional[str] = None, + default_log_format: str = "tqdm", +): + if log_format is None: + log_format = default_log_format + if log_format == "tqdm" and not sys.stderr.isatty(): + log_format = "simple" + + if log_format == "json": + bar = JsonProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == "none": + bar = NoopProgressBar(iterator, epoch, prefix) + elif log_format == "simple": + bar = SimpleProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == "tqdm": + bar = TqdmProgressBar(iterator, epoch, prefix) + else: + raise ValueError("Unknown log format: {}".format(log_format)) + + if tensorboard_logdir: + try: + # [FB only] custom wrapper for TensorBoard + import palaas # noqa + from .fb_tbmf_wrapper import FbTbmfWrapper + + bar = FbTbmfWrapper(bar, log_interval) + except ImportError: + bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) + + return bar + + +def build_progress_bar( + args, + iterator, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + default: str = "tqdm", + no_progress_bar: str = "none", +): + """Legacy wrapper that takes an argparse.Namespace.""" + if getattr(args, "no_progress_bar", False): + default = no_progress_bar + if getattr(args, "distributed_rank", 0) == 0: + tensorboard_logdir = getattr(args, "tensorboard_logdir", None) + else: + tensorboard_logdir = None + return progress_bar( + iterator, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch, + prefix=prefix, + tensorboard_logdir=tensorboard_logdir, + default_log_format=default, + ) + + +def format_stat(stat): + if isinstance(stat, Number): + stat = "{:g}".format(stat) + elif isinstance(stat, AverageMeter): + stat = "{:.3f}".format(stat.avg) + elif isinstance(stat, TimeMeter): + stat = "{:g}".format(round(stat.avg)) + elif isinstance(stat, StopwatchMeter): + stat = "{:g}".format(round(stat.sum)) + elif torch.is_tensor(stat): + stat = stat.tolist() + return stat + + +class BaseProgressBar(object): + """Abstract class for progress bars.""" + + def __init__(self, iterable, epoch=None, prefix=None): + self.iterable = iterable + self.n = getattr(iterable, "n", 0) + self.epoch = epoch + self.prefix = "" + if epoch is not None: + self.prefix += "epoch {:03d}".format(epoch) + if prefix is not None: + self.prefix += " | {}".format(prefix) + + def __len__(self): + return len(self.iterable) + + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def __iter__(self): + raise NotImplementedError + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + raise NotImplementedError + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + raise NotImplementedError + + def _str_commas(self, stats): + return ", ".join(key + "=" + stats[key].strip() for key in stats.keys()) + + def _str_pipes(self, stats): + return " | ".join(key + " " + stats[key].strip() for key in stats.keys()) + + def _format_stats(self, stats): + postfix = OrderedDict(stats) + # Preprocess stats according to datatype + for key in postfix.keys(): + postfix[key] = str(format_stat(postfix[key])) + return postfix + + +@contextmanager +def rename_logger(logger, new_name): + old_name = logger.name + if new_name is not None: + logger.name = new_name + yield logger + logger.name = old_name + + +class JsonProgressBar(BaseProgressBar): + """Log output in JSON format.""" + + def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): + super().__init__(iterable, epoch, prefix) + self.log_interval = log_interval + self.i = None + self.size = None + + def __iter__(self): + self.size = len(self.iterable) + for i, obj in enumerate(self.iterable, start=self.n): + self.i = i + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + step = step or self.i or 0 + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: + update = ( + self.epoch - 1 + (self.i + 1) / float(self.size) + if self.epoch is not None + else None + ) + stats = self._format_stats(stats, epoch=self.epoch, update=update) + with rename_logger(logger, tag): + logger.info(json.dumps(stats)) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self.stats = stats + if tag is not None: + self.stats = OrderedDict( + [(tag + "_" + k, v) for k, v in self.stats.items()] + ) + stats = self._format_stats(self.stats, epoch=self.epoch) + with rename_logger(logger, tag): + logger.info(json.dumps(stats)) + + def _format_stats(self, stats, epoch=None, update=None): + postfix = OrderedDict() + if epoch is not None: + postfix["epoch"] = epoch + if update is not None: + postfix["update"] = round(update, 3) + # Preprocess stats according to datatype + for key in stats.keys(): + postfix[key] = format_stat(stats[key]) + return postfix + + +class NoopProgressBar(BaseProgressBar): + """No logging.""" + + def __init__(self, iterable, epoch=None, prefix=None): + super().__init__(iterable, epoch, prefix) + + def __iter__(self): + for obj in self.iterable: + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + pass + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + pass + + +class SimpleProgressBar(BaseProgressBar): + """A minimal logger for non-TTY environments.""" + + def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): + super().__init__(iterable, epoch, prefix) + self.log_interval = log_interval + self.i = None + self.size = None + + def __iter__(self): + self.size = len(self.iterable) + for i, obj in enumerate(self.iterable, start=self.n): + self.i = i + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + step = step or self.i or 0 + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: + stats = self._format_stats(stats) + postfix = self._str_commas(stats) + with rename_logger(logger, tag): + logger.info( + "{}: {:5d} / {:d} {}".format( + self.prefix, self.i + 1, self.size, postfix + ) + ) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + postfix = self._str_pipes(self._format_stats(stats)) + with rename_logger(logger, tag): + logger.info("{} | {}".format(self.prefix, postfix)) + + +class TqdmProgressBar(BaseProgressBar): + """Log to tqdm.""" + + def __init__(self, iterable, epoch=None, prefix=None): + super().__init__(iterable, epoch, prefix) + from tqdm import tqdm + + self.tqdm = tqdm( + iterable, + self.prefix, + leave=False, + disable=(logger.getEffectiveLevel() > logging.INFO), + ) + + def __iter__(self): + return iter(self.tqdm) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + self.tqdm.set_postfix(self._format_stats(stats), refresh=False) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + postfix = self._str_pipes(self._format_stats(stats)) + with rename_logger(logger, tag): + logger.info("{} | {}".format(self.prefix, postfix)) + + +try: + _tensorboard_writers = {} + from tensorboardX import SummaryWriter +except ImportError: + SummaryWriter = None + + +def _close_writers(): + for w in _tensorboard_writers.values(): + w.close() + + +atexit.register(_close_writers) + + +class TensorboardProgressBarWrapper(BaseProgressBar): + """Log to tensorboard.""" + + def __init__(self, wrapped_bar, tensorboard_logdir): + self.wrapped_bar = wrapped_bar + self.tensorboard_logdir = tensorboard_logdir + + if SummaryWriter is None: + logger.warning( + "tensorboard not found, please install with: pip install tensorboardX" + ) + + def _writer(self, key): + if SummaryWriter is None: + return None + _writers = _tensorboard_writers + if key not in _writers: + _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key)) + _writers[key].add_text("sys.argv", " ".join(sys.argv)) + return _writers[key] + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to tensorboard.""" + self._log_to_tensorboard(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_tensorboard(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def _log_to_tensorboard(self, stats, tag=None, step=None): + writer = self._writer(tag or "") + if writer is None: + return + if step is None: + step = stats["num_updates"] + for key in stats.keys() - {"num_updates"}: + if isinstance(stats[key], AverageMeter): + writer.add_scalar(key, stats[key].val, step) + elif isinstance(stats[key], Number): + writer.add_scalar(key, stats[key], step) + writer.flush() diff --git a/fairseq/model_parallel/__init__.py b/fairseq/model_parallel/__init__.py new file mode 100644 index 0000000..69f2168 --- /dev/null +++ b/fairseq/model_parallel/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import criterions, models, modules # noqa diff --git a/fairseq/model_parallel/criterions/__init__.py b/fairseq/model_parallel/criterions/__init__.py new file mode 100644 index 0000000..6239b50 --- /dev/null +++ b/fairseq/model_parallel/criterions/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the criterions/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.model_parallel.criterions." + module) diff --git a/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py b/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py new file mode 100644 index 0000000..35c50ee --- /dev/null +++ b/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion + + +try: + from fairseq.model_parallel.megatron.mpu.cross_entropy import ( + vocab_parallel_cross_entropy, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +@register_criterion("vocab_parallel_cross_entropy") +class VocabParallelCrossEntropyCriterion(FairseqCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + target = sample["target"] + + loss = vocab_parallel_cross_entropy(net_output[0].float(), target) + loss = (loss * (target != self.padding_idx)).sum() + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py new file mode 100644 index 0000000..cf83685 --- /dev/null +++ b/fairseq/model_parallel/megatron_trainer.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Train a network across multiple GPUs. +""" + +from fairseq import distributed_utils +from fairseq.trainer import Trainer +from fairseq.dataclass.configs import FairseqConfig + + +try: + from fairseq.model_parallel.megatron.mpu import ( + get_data_parallel_group, + get_data_parallel_rank, + get_data_parallel_world_size, + get_model_parallel_group, + get_model_parallel_src_rank, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +class MegatronTrainer(Trainer): + """Main class for model parallel with data parallel training.""" + + def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs): + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + super().__init__(cfg, task, model, criterion, **kwargs) + + @property + def data_parallel_world_size(self): + return get_data_parallel_world_size() + + @property + def data_parallel_process_group(self): + return get_data_parallel_group() + + @property + def data_parallel_rank(self): + return get_data_parallel_rank() + + @property + def is_data_parallel_master(self): + return get_model_parallel_src_rank() == 0 + + def clip_grad_norm(self, clip_norm): + def _aggregate_model_parallel_grad_norm(total_norm): + total_norm = total_norm ** 2 + distributed_utils.all_reduce(total_norm, group=get_model_parallel_group()) + total_norm = total_norm ** 0.5 + return total_norm + + return self.optimizer.clip_grad_norm( + clip_norm, + aggregate_norm_fn=_aggregate_model_parallel_grad_norm, + ) diff --git a/fairseq/model_parallel/models/__init__.py b/fairseq/model_parallel/models/__init__.py new file mode 100644 index 0000000..3532479 --- /dev/null +++ b/fairseq/model_parallel/models/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.model_parallel.models." + model_name) diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py new file mode 100644 index 0000000..117827c --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .model import * # noqa diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py new file mode 100644 index 0000000..eb81ded --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py @@ -0,0 +1,600 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import options, utils +from fairseq.modules import ( + AdaptiveSoftmax, + LayerNorm, + MultiheadAttention, + PositionalEmbedding, +) + + +EncoderOut = namedtuple( + "TransformerEncoderOut", + [ + "encoder_out", # T x B x C + "encoder_padding_mask", # B x T + "encoder_embedding", # B x T x C + "encoder_states", # List[T x B x C] + ], +) + + +class TransformerEncoderEmbedding(nn.Module): + """ Encoder Embedding + Positional Embedding """ + + def __init__(self, args, embed_tokens): + super().__init__() + self.dropout = args.dropout + self.max_source_positions = args.max_source_positions + self.embed_tokens = embed_tokens + if isinstance(embed_tokens, nn.ModuleList): + self.padding_idx = embed_tokens[0].padding_idx + embed_dim = sum(e.embedding_dim for e in embed_tokens) + else: + self.padding_idx = embed_tokens.padding_idx + embed_dim = embed_tokens.embedding_dim + self.embed_scale = math.sqrt(embed_dim) + self.embed_positions = ( + PositionalEmbedding( + args.max_source_positions, + embed_dim, + self.padding_idx, + learned=args.encoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + def forward(self, input): + # embed tokens and positions + src_tokens = input[0] + prev_output_tokens = input[2] + if isinstance(self.embed_tokens, nn.ModuleList): + x_embed_list = [] + for embed_tokens_part in self.embed_tokens: + x_embed_list.append(embed_tokens_part(src_tokens)) + + embedded = torch.cat(x_embed_list, dim=-1) + else: + embedded = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * embedded + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + if self.layernorm_embedding: + x = self.layernorm_embedding(x) + x = F.dropout(x, p=self.dropout, training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + return (x, encoder_padding_mask, prev_output_tokens) + + +class TransformerEncoderLayerNorm(nn.Module): + """ + Layer norm at the the end of all encoder layers if + args.encoder_enormalize_before = True + """ + + def __init__(self, args, embed_dim): + super().__init__() + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, input): + x = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + if self.layer_norm: + x = self.layer_norm(x) + # keeping track of the incremental_state is not supported yet + return (x, encoder_padding_mask, prev_output_tokens) + + +class TransformerDecoderEmbedding(nn.Module): + """ Decoder Embedding + Positional Embedding """ + + def __init__(self, args, embed_tokens): + super().__init__() + self.dropout = args.dropout + self.share_input_output_embed = args.share_decoder_input_output_embed + input_embed_dim = ( + sum(e.embedding_dim for e in embed_tokens) + if isinstance(embed_tokens, nn.ModuleList) + else embed_tokens.embedding_dim + ) + embed_dim = args.decoder_embed_dim + self.output_embed_dim = args.decoder_output_dim + + padding_idx = ( + embed_tokens[0].padding_idx + if isinstance(embed_tokens, nn.ModuleList) + else embed_tokens.padding_idx + ) + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + args.max_target_positions, + embed_dim, + padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + def forward(self, input): + mt_task = False + if isinstance(input, tuple): + if len(input) == 3: + encoder_out = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + incremental_state = None # Hardcoding to avoid passing of None objects + mt_task = True + else: + # HACK for now, need to fix (TODO sidgoyal) + prev_output_tokens = input[0] + # discard "src_lengths" + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + else: + prev_output_tokens = input + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + positions = ( + self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + + if isinstance(self.embed_tokens, nn.ModuleList): + x_embed_list = [] + for embed_tokens_part in self.embed_tokens: + x_embed_list.append(embed_tokens_part(prev_output_tokens)) + + x = self.embed_scale * torch.cat(x_embed_list, dim=-1) + else: + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + if mt_task: + return (x, encoder_out, encoder_padding_mask) + return x + + +class TransformerDecoderOutputLayer(nn.Module): + def __init__(self, args, embed_tokens, dictionary): + super().__init__() + self.share_input_output_embed = args.share_decoder_input_output_embed + self.embed_tokens = embed_tokens + self.output_embed_dim = args.decoder_output_dim + embed_dim = args.decoder_embed_dim + + self.project_out_dim = ( + Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights + else None + ) + self.adaptive_softmax = None + if args.adaptive_softmax_cutoff is not None: + assert not isinstance(embed_tokens, nn.ModuleList) + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + self.output_embed_dim, + options.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif not self.share_input_output_embed: + self.embed_tokens = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_( + self.embed_tokens, mean=0, std=self.output_embed_dim ** -0.5 + ) + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, input, apply_final_proj=True): + if isinstance(input, tuple): + x = input[0] + else: + x = input + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + if apply_final_proj: + x = self.output_layer(x) + return x + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + if self.share_input_output_embed: + if isinstance(self.embed_tokens, nn.ModuleList): + output = None + for i, emb in enumerate(self.embed_tokens): + sidx = i * emb.embedding_dim + eidx = (i + 1) * emb.embedding_dim + if output is None: + output = F.linear(features[:, :, sidx:eidx], emb.weight) + else: + output += F.linear(features[:, :, sidx:eidx], emb.weight) + + return output + else: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_tokens) + else: + return features + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer block. + In the original paper each operation (multi-head attention or FFN) is + postprocessed with: `dropout -> add residual -> layernorm`. In the + tensor2tensor code they suggest that learning is more robust when + preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.encoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + def __init__(self, args): + super().__init__() + self.embed_dim = args.encoder_embed_dim + self.self_attn = MultiheadAttention( + self.embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + ) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.dropout = args.dropout + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, "activation_fn", "relu") + ) + self.activation_dropout = getattr(args, "activation_dropout", 0) + if self.activation_dropout == 0: + # for backwards compatibility with models that use args.relu_dropout + self.activation_dropout = getattr(args, "relu_dropout", 0) + self.normalize_before = args.encoder_normalize_before + self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) + self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + + def upgrade_state_dict_named(self, state_dict, name): + """ + Rename layer norm states from `...layer_norms.0.weight` to + `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to + `...final_layer_norm.weight` + """ + layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layer_norms.{}.{}".format(name, old, m) + if k in state_dict: + state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] + del state_dict[k] + + def forward(self, input): + """ + Args: + input (Tuple): + input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + input[1] (ByteTensor/FloatTensor): encoder padding mask - + binary ByteTensor of shape `(batch, src_len)` where padding elements + are indicated by ``1``. + input[2] (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing) + Returns: + output (Tuple): + output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)` + output[1] (ByteTensor/FloatTensor): encoder padding mask + output[2] (LongTensor): previous decoder outputs + """ + x = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + x, _ = self.self_attn( + query=x, key=x, value=x, key_padding_mask=encoder_padding_mask + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + return (x, encoder_padding_mask, prev_output_tokens) + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + +class TransformerDecoderLayer(nn.Module): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.decoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.self_attn = MultiheadAttention( + embed_dim=self.embed_dim, + num_heads=args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=True, + ) + self.dropout = args.dropout + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, "activation_fn", "relu") + ) + self.activation_dropout = getattr(args, "activation_dropout", 0) + if self.activation_dropout == 0: + # for backwards compatibility with models that use args.relu_dropout + self.activation_dropout = getattr(args, "relu_dropout", 0) + self.normalize_before = args.decoder_normalize_before + + # use layerNorm rather than FusedLayerNorm for exporting. + # char_inputs can be used to determint this. + # TODO remove this once we update apex with the fix + export = getattr(args, "char_inputs", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = MultiheadAttention( + self.embed_dim, + args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + ) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) + self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.need_attn = True + + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def forward(self, input): + """ + Args: + input (Tuple): + input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)` + input[2] (ByteTensor/FloatTensor): encoder padding mask - + binary ByteTensor of shape `(batch, src_len)` where padding elements + are indicated by ``1``. + Returns: + output (Tuple): + output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)` + output[1] (ByteTensor/FloatTensor): encoder padding mask + output[2] (LongTensor): previous decoder outputs + """ + # Note: incremental state is not yet supported + mt_task = False + if isinstance(input, tuple): + x = input[0] + encoder_out = input[1] + encoder_padding_mask = input[2] + incremental_state = None + mt_task = True + else: + x = input + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + if incremental_state is None: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + # TODO: add back prev_self_attn_state, prev_attn_state, + # self_attn_padding_mask + prev_self_attn_state = None + prev_attn_state = None + self_attn_padding_mask = None + + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + if prev_self_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_self_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.self_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + if self.encoder_attn is not None: + residual = x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) + if prev_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=(not self.training and self.need_attn), + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + + if mt_task: + return (x, encoder_out, encoder_padding_mask) + return x + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu( + utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py new file mode 100644 index 0000000..7873611 --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -0,0 +1,711 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import ( + Embedding, + TransformerDecoderEmbedding, + TransformerDecoderLayer, + TransformerDecoderOutputLayer, + TransformerEncoderEmbedding, + TransformerEncoderLayer, + TransformerEncoderLayerNorm, +) +from fairseq.models import ( + BaseFairseqModel, + FairseqDecoder, + FairseqEncoder, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.transformer import ( + base_architecture, + transformer_iwslt_de_en, + transformer_wmt_en_de_big, +) +from fairseq.modules import SinusoidalPositionalEmbedding + + +logger = logging.getLogger(__name__) + + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@register_model("pipeline_parallel_transformer") +class PipelineParallelTransformerModel(BaseFairseqModel): + def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): + try: + from fairscale.nn import Pipe + except ImportError: + raise ImportError("Please install fairscale with: pip install fairscale") + super().__init__() + assert isinstance(encoder, FairseqEncoder) + assert isinstance(decoder, FairseqDecoder) + encoder_module_list = ( + [encoder.embedding_layer] + + list(encoder.encoder_layers) + + [encoder.final_layer_norm] + ) + self.num_encoder_modules = len(encoder_module_list) + decoder_module_list = ( + [decoder.embedding_layer] + + list(decoder.decoder_layers) + + [decoder.decoder_output_layer] + ) + self.num_decoder_modules = len(decoder_module_list) + module_list = encoder_module_list + decoder_module_list + self.devices = devices + self.model = Pipe( + nn.Sequential(*module_list), + balance=balance, + devices=devices, + chunks=chunks, + checkpoint=checkpoint, + ) + self.encoder_max_positions = self.max_positions_helper( + encoder.embedding_layer, "max_source_positions" + ) + self.decoder_max_positions = self.max_positions_helper( + decoder.embedding_layer, "max_target_positions" + ) + self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None) + # Note: To be populated during inference + self.encoder = None + self.decoder = None + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + if self.training: + input_lst = [src_tokens, src_lengths, prev_output_tokens] + input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) + return self.model(input) + else: + assert self.encoder is not None and self.decoder is not None, ( + "encoder and decoder need to be initialized by " + + "calling the `prepare_for_inference_()` method" + ) + encoder_output_tuple = self.encoder(input) + return self.decoder(encoder_output_tuple) + + def prepare_for_inference_(self, cfg): + if self.encoder is not None and self.decoder is not None: + logger.info("Encoder and Decoder already initialized") + return + encoder_module_list = [] + decoder_module_list = [] + module_count = 0 + for partition in self.model.partitions: + for module in partition: + if module_count < self.num_encoder_modules: + encoder_module_list.append(module) + else: + decoder_module_list.append(module) + module_count += 1 + self.model = None + self.encoder = TransformerEncoder(cfg.distributed_training, None, None, encoder_module_list) + self.decoder = TransformerDecoder( + cfg.distributed_training, None, None, decoder_module_list=decoder_module_list + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--activation-fn', + choices=utils.get_available_activation_fns(), + help='activation function to use') + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--attention-dropout', type=float, metavar='D', + help='dropout probability for attention weights') + parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', + help='dropout probability after activation in FFN.') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', + help='encoder embedding dimension for FFN') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='num encoder layers') + parser.add_argument('--encoder-attention-heads', type=int, metavar='N', + help='num encoder attention heads') + parser.add_argument('--encoder-normalize-before', action='store_true', + help='apply layernorm before each encoder block') + parser.add_argument('--encoder-learned-pos', action='store_true', + help='use learned positional embeddings in the encoder') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', + help='decoder embedding dimension for FFN') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='num decoder layers') + parser.add_argument('--decoder-attention-heads', type=int, metavar='N', + help='num decoder attention heads') + parser.add_argument('--decoder-learned-pos', action='store_true', + help='use learned positional embeddings in the decoder') + parser.add_argument('--decoder-normalize-before', action='store_true', + help='apply layernorm before each decoder block') + parser.add_argument('--share-decoder-input-output-embed', action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', + help='if set, disables positional embeddings (outside self attention)') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion'), + parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', + help='sets adaptive softmax dropout for the tail projections') + parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1, + help='Number of embedding layer chunks (enables more even distribution' + 'of optimizer states across data parallel nodes' + 'when using optimizer state sharding and' + 'a big embedding vocabulary)') + # fmt: on + + @classmethod + def build_model_base(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not hasattr(args, "max_source_positions"): + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if not hasattr(args, "max_target_positions"): + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1): + assert embed_dim % num_embed_chunks == 0, ( + f"Number of embedding chunks = {num_embed_chunks} should be " + + f"divisible by the embedding dimension = {embed_dim}" + ) + assert path is None or num_embed_chunks == 1, ( + "Loading embedding from a path with number of embedding chunks > 1" + + " is not yet supported" + ) + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + # if provided, load from preloaded dictionaries + if path: + emb = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + else: + embed_chunk_dim = embed_dim // num_embed_chunks + emb = nn.ModuleList() + for i in range(num_embed_chunks): + emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx)) + return emb + + num_embed_chunks = args.num_embedding_chunks + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError("--share-all-embeddings requires a joined dictionary") + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = build_embedding( + src_dict, + args.encoder_embed_dim, + args.encoder_embed_path, + num_embed_chunks, + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + assert args.share_decoder_input_output_embed or num_embed_chunks == 1, ( + "Not sharing decoder I/O embeddings is not yet supported with number of " + + "embedding chunks > 1" + ) + encoder_embed_tokens = build_embedding( + src_dict, + args.encoder_embed_dim, + args.encoder_embed_path, + num_embed_chunks, + ) + decoder_embed_tokens = build_embedding( + tgt_dict, + args.decoder_embed_dim, + args.decoder_embed_path, + num_embed_chunks, + ) + + encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) + decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + return (encoder, decoder) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return TransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoder(args, tgt_dict, embed_tokens) + + @classmethod + def build_model(cls, args, task): + encoder, decoder = cls.build_model_base(args, task) + return PipelineParallelTransformerModel( + encoder=encoder, + decoder=decoder, + balance=utils.eval_str_list(args.pipeline_balance, type=int), + devices=utils.eval_str_list(args.pipeline_devices, type=int), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder_max_positions, self.decoder_max_positions) + + def max_positions_helper( + self, embedding_layer, max_positions_field="max_source_positions" + ): + """Maximum input length supported by the encoder or decoder.""" + if embedding_layer.embed_positions is None: + return getattr(embedding_layer, max_positions_field) + return min( + getattr(embedding_layer, max_positions_field), + embedding_layer.embed_positions.max_positions, + ) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: + if sample is not None: + assert "target" in sample + target = sample["target"] + else: + target = None + out = self.adaptive_softmax.get_log_prob(net_output, target=target) + return out.exp_() if not log_probs else out + + # A Pipe() module returns a tuple of tensors as the output. + # In this case, the tuple has one element - the output tensor of logits + logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0] + if log_probs: + return utils.log_softmax(logits, dim=-1, onnx_trace=False) + else: + return utils.softmax(logits, dim=-1, onnx_trace=False) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder_max_positions + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + self.upgrade_state_dict(state_dict) + is_regular_transformer = not any("model.partitions" in k for k in state_dict) + if is_regular_transformer: + state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict) + return super().load_state_dict(state_dict, strict) + + def convert_to_pipeline_parallel_state_dict(self, state_dict): + new_state_dict = self.state_dict() + encoder_layer_idx = 0 + decoder_layer_idx = 0 + encoder_key_suffixes = [ + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "self_attn_layer_norm.weight", + "self_attn_layer_norm.bias", + "fc1.weight", + "fc1.bias", + "fc2.weight", + "fc2.bias", + "final_layer_norm.weight", + "final_layer_norm.bias", + ] + decoder_key_suffixes = [ + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "self_attn_layer_norm.weight", + "self_attn_layer_norm.bias", + "encoder_attn.k_proj.weight", + "encoder_attn.k_proj.bias", + "encoder_attn.v_proj.weight", + "encoder_attn.v_proj.bias", + "encoder_attn.q_proj.weight", + "encoder_attn.q_proj.bias", + "encoder_attn.out_proj.weight", + "encoder_attn.out_proj.bias", + "encoder_attn_layer_norm.weight", + "encoder_attn_layer_norm.bias", + "fc1.weight", + "fc1.bias", + "fc2.weight", + "fc2.bias", + "final_layer_norm.weight", + "final_layer_norm.bias", + ] + for pid, partition in enumerate(self.model.partitions): + logger.info(f"Begin Partition {pid}") + for mid, module in enumerate(partition): + # fmt: off + if isinstance(module, TransformerEncoderEmbedding): + new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'] + new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['encoder.embed_positions._float_tensor'] + if isinstance(module, TransformerEncoderLayer): + for suffix in encoder_key_suffixes: + new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}'] + encoder_layer_idx += 1 + if isinstance(module, TransformerDecoderLayer): + for suffix in decoder_key_suffixes: + new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}'] + decoder_layer_idx += 1 + if isinstance(module, TransformerEncoderLayerNorm): + if 'encoder.layer_norm.weight' in state_dict: + new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight'] + new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias'] + if isinstance(module, TransformerDecoderEmbedding): + new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'] + new_state_dict[f'model.partitions.{pid}.{mid}.embed_positions._float_tensor'] = state_dict['decoder.embed_positions._float_tensor'] + if isinstance(module, TransformerDecoderOutputLayer): + new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight'] + # fmt: on + return new_state_dict + + +class TransformerEncoder(FairseqEncoder): + """ + Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + try: + from fairscale.nn import Pipe + except ImportError: + raise ImportError("Please install fairscale with: pip install fairscale") + self.use_pipeline = encoder_module_list is not None + if not self.use_pipeline: + self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) + self.encoder_layers = nn.Sequential(*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]) + if isinstance(embed_tokens, nn.ModuleList): + emb_dim = sum(e.embedding_dim for e in embed_tokens) + else: + emb_dim = embed_tokens.embedding_dim + self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim) + else: + encoder_balance = utils.eval_str_list( + args.pipeline_encoder_balance, type=int + ) + encoder_devices = utils.eval_str_list( + args.pipeline_encoder_devices, type=int + ) + assert sum(encoder_balance) == len(encoder_module_list), ( + f"Sum of encoder_balance={encoder_balance} is not equal " + + f"to num_encoder_modules={len(encoder_module_list)}" + ) + self.model = Pipe( + module=nn.Sequential(*encoder_module_list), + balance=encoder_balance, + devices=encoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def forward(self, src_tokens, src_lengths): + """ + Args: + input_tuple( + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + ) + + Returns: + output_tuple( + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - prev_output_tokens + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + ) + """ + dummy_prev_output_tokens = torch.zeros( + 1, dtype=src_tokens.dtype, device=src_tokens.device + ) + input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens) + if self.use_pipeline: + input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) + encoder_out = self.model(input_tuple) + else: + encoder_embed_output_tuple = self.embedding_layer(input_tuple) + encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) + encoder_out = self.final_layer_norm(encoder_layers_output) + # first element is the encoder output + # second element is the encoder padding mask + # the remaining elements of EncoderOut are not computed by + # the PipelineParallelTransformer + return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if encoder_out.encoder_out is not None: + encoder_out = encoder_out._replace( + encoder_out=encoder_out.encoder_out.index_select(1, new_order) + ) + if encoder_out.encoder_padding_mask is not None: + encoder_out = encoder_out._replace( + encoder_padding_mask=encoder_out.encoder_padding_mask.index_select( + 0, new_order + ) + ) + if encoder_out.encoder_embedding is not None: + encoder_out = encoder_out._replace( + encoder_embedding=encoder_out.encoder_embedding.index_select( + 0, new_order + ) + ) + if encoder_out.encoder_states is not None: + for idx, state in enumerate(encoder_out.encoder_states): + encoder_out.encoder_states[idx] = state.index_select(1, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embedding_layer.embed_positions is None: + return self.embedding_layer.max_source_positions + return min( + self.embedding_layer.max_source_positions, + self.embedding_layer.embed_positions.max_positions, + ) + + +class TransformerDecoder(FairseqDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + decoder_module_list=None, + ): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + try: + from fairscale.nn import Pipe + except ImportError: + raise ImportError("Please install fairscale with: pip install fairscale") + self.use_pipeline = decoder_module_list is not None + if not self.use_pipeline: + self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) + self.decoder_layers = nn.Sequential(*[ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ]) + self.decoder_output_layer = TransformerDecoderOutputLayer( + args, embed_tokens, dictionary + ) + else: + decoder_balance = utils.eval_str_list( + args.pipeline_decoder_balance, type=int + ) + decoder_devices = utils.eval_str_list( + args.pipeline_decoder_devices, type=int + ) + assert sum(decoder_balance) == len(decoder_module_list), ( + f"Sum of decoder_balance={decoder_balance} is not equal " + + f"to num_decoder_modules={len(decoder_module_list)}" + ) + self.model = Pipe( + module=nn.Sequential(*decoder_module_list), + balance=decoder_balance, + devices=decoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def forward( + self, + prev_output_tokens, + encoder_out=None, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + input_tuple = ( + encoder_out.encoder_out, + encoder_out.encoder_padding_mask, + prev_output_tokens, + ) + if self.use_pipeline: + input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) + return (self.model(input_tuple),) + else: + embed_layer_output = self.embedding_layer(input_tuple) + state = self.decoder_layers(embed_layer_output) + return (self.decoder_output_layer(state),) + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + else: + return features + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embedding_layer.embed_positions is None: + return self.embedding_layer.max_target_positions + return min( + self.embedding_layer.max_target_positions, + self.embedding_layer.embed_positions.max_positions, + ) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + weights_key = "{}.embed_positions.weights".format(name) + if weights_key in state_dict: + del state_dict[weights_key] + state_dict[ + "{}.embed_positions._float_tensor".format(name) + ] = torch.FloatTensor(1) + + for i in range(len(self.layers)): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +@register_model_architecture( + "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel" +) +def transformer_iwslt_de_en_dist(args): + transformer_iwslt_de_en(args) + + +@register_model_architecture( + "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel" +) +def transformer_wmt_en_de_big_dist(args): + transformer_wmt_en_de_big(args) diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py new file mode 100644 index 0000000..4f34645 --- /dev/null +++ b/fairseq/model_parallel/models/transformer.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch.nn as nn +import torch.nn.functional as F +from fairseq.model_parallel.modules import ( + ModelParallelTransformerDecoderLayer, + ModelParallelTransformerEncoderLayer, +) +from fairseq.models import register_model +from fairseq.models.transformer import ( + TransformerDecoder, + TransformerEncoder, + TransformerModel, +) + + +try: + from fairseq.model_parallel.megatron.mpu import ( + copy_to_model_parallel_region, + gather_from_model_parallel_region, + VocabParallelEmbedding, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +logger = logging.getLogger(__name__) + + +@register_model("model_parallel_transformer") +class ModelParallelTransformerModel(TransformerModel): + """ + Model parallel Transformer model. + """ + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + dictionary.pad_to_multiple_(args.model_parallel_size * 8) + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + def _vocab_init(tensor, **kwargs): + nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5) + nn.init.constant_(tensor[1], 0) + + emb = VocabParallelEmbedding( + num_embeddings, embed_dim, padding_idx, init_method=_vocab_init + ) + # if provided, load from preloaded dictionaries + if path: + raise NotImplementedError( + "Loading of embedding from path is not supported for model parallel" + ) + return emb + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return ModelParallelTransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return ModelParallelTransformerDecoder( + args, + tgt_dict, + embed_tokens, + no_encoder_attn=getattr(args, "no_cross_attention", False), + ) + + +class ModelParallelTransformerEncoder(TransformerEncoder): + """ + Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`ModelParallelTransformerEncoderLayer`. + """ + + def build_encoder_layer(self, args): + return ModelParallelTransformerEncoderLayer(args) + + +class ModelParallelTransformerDecoder(TransformerDecoder): + """ + Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`ModelParallelTransformerDecoderLayer`. + """ + + def build_decoder_layer(self, args, no_encoder_attn=False): + return ModelParallelTransformerDecoderLayer(args, no_encoder_attn) + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if not self.share_input_output_embed: + raise NotImplementedError( + "Model parallel training currently requires --share-decoder-input-output-embed" + ) + + features = copy_to_model_parallel_region(features) + + # project back to size of vocabulary + x = self.output_projection(features) + + if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy": + x = gather_from_model_parallel_region(x).contiguous() + return x diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py new file mode 100644 index 0000000..dc52f6e --- /dev/null +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -0,0 +1,174 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer_lm import TransformerLanguageModel + + +try: + from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@register_model("model_parallel_transformer_lm") +class ModelParallelTransformerLanguageModel(TransformerLanguageModel): + + @staticmethod + def add_args(parser): + TransformerLanguageModel.add_args(parser) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + + # make sure all arguments are present in older models + base_lm_architecture(args) + + task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + if args.character_embeddings: + raise NotImplementedError( + "Character embeddings is not supported for model parallel" + ) + elif args.adaptive_input: + raise NotImplementedError( + "Adaptive input is not supported for model parallel" + ) + else: + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_input_dim + ) + + decoder = ModelParallelTransformerDecoder( + args, + task.target_dictionary, + embed_tokens, + no_encoder_attn=True, + ) + return cls(decoder) + + @staticmethod + def add_args(parser): + TransformerLanguageModel.add_args(parser) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + def _vocab_init(tensor, **kwargs): + nn.init.normal_(tensor, mean=0, std=embed_dim ** -0.5) + nn.init.constant_(tensor[1], 0) + + embed_tokens = VocabParallelEmbedding( + len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init + ) + return embed_tokens + + +def base_lm_architecture(args): + # backward compatibility for older model checkpoints + if hasattr(args, "no_tie_adaptive_proj"): + # previous models defined --no-tie-adaptive-proj, so use the existence of + # that option to determine if this is an "old" model checkpoint + args.no_decoder_final_norm = True # old models always set this to True + if args.no_tie_adaptive_proj is False: + args.tie_adaptive_proj = True + if hasattr(args, "decoder_final_norm"): + args.no_decoder_final_norm = not args.decoder_final_norm + + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.relu_dropout = getattr(args, "relu_dropout", 0.0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + # Model training is not stable without this + args.decoder_normalize_before = True + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.character_embeddings = getattr(args, "character_embeddings", False) + args.character_filters = getattr( + args, + "character_filters", + "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + ) + args.character_embedding_dim = getattr(args, "character_embedding_dim", 4) + args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0) + args.add_bos_token = getattr(args, "add_bos_token", False) + + +@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron") +def transformer_lm_megatron(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture( + "model_parallel_transformer_lm", "transformer_lm_megatron_11b" +) +def transformer_lm_megatron_11b(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py new file mode 100644 index 0000000..fb45b3c --- /dev/null +++ b/fairseq/model_parallel/modules/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .multihead_attention import ModelParallelMultiheadAttention +from .transformer_layer import ( + ModelParallelTransformerEncoderLayer, + ModelParallelTransformerDecoderLayer, +) +from .transformer_sentence_encoder_layer import ( + ModelParallelTransformerSentenceEncoderLayer, +) +from .transformer_sentence_encoder import ModelParallelTransformerSentenceEncoder + +__all__ = [ + "ModelParallelMultiheadAttention", + "ModelParallelTransformerEncoderLayer", + "ModelParallelTransformerDecoderLayer", + "ModelParallelTransformerSentenceEncoder", + "ModelParallelTransformerSentenceEncoderLayer", +] diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py new file mode 100644 index 0000000..4164bf9 --- /dev/null +++ b/fairseq/model_parallel/modules/multihead_attention.py @@ -0,0 +1,352 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from torch import Tensor, nn + + +try: + from fairseq.model_parallel.megatron.mpu import ( + get_cuda_rng_tracker, + get_model_parallel_world_size, + ColumnParallelLinear, + RowParallelLinear, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +@with_incremental_state +class ModelParallelMultiheadAttention(nn.Module): + """Model parallel Multi-headed attention. + This performs the Multi-headed attention over multiple gpus. + + See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + self_attention=False, + encoder_decoder_attention=False, + ): + super().__init__() + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.model_parallel_size = get_model_parallel_world_size() + + self.num_heads_partition = num_heads // self.model_parallel_size + assert ( + self.num_heads_partition * self.model_parallel_size == num_heads + ), "Number of heads must be divisible by model parallel size" + + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert ( + not self.self_attention or self.qkv_same_dim + ), "Self-attention requires query, key and value to be of the same size" + + self.k_proj = ColumnParallelLinear( + self.kdim, embed_dim, bias=bias, gather_output=False + ) + self.v_proj = ColumnParallelLinear( + self.vdim, embed_dim, bias=bias, gather_output=False + ) + self.q_proj = ColumnParallelLinear( + embed_dim, embed_dim, bias=bias, gather_output=False + ) + self.out_proj = RowParallelLinear( + embed_dim, embed_dim, bias=bias, input_is_parallel=True + ) + + self.tpu = False + + def prepare_for_tpu_(self, **kwargs): + self.tpu = True + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + **unused_kwargs, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + """ + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads_partition, self.head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads_partition, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads_partition, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view( + bsz * self.num_heads_partition, -1, self.head_dim + ) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view( + bsz * self.num_heads_partition, -1, self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = ( + ModelParallelMultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + ) + + saved_state["prev_key"] = k.view( + bsz, self.num_heads_partition, -1, self.head_dim + ) + saved_state["prev_value"] = v.view( + bsz, self.num_heads_partition, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + assert list(attn_weights.size()) == [ + bsz * self.num_heads_partition, + tgt_len, + src_len, + ] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view( + bsz, self.num_heads_partition, tgt_len, src_len + ) + if not self.tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view( + bsz * self.num_heads_partition, tgt_len, src_len + ) + + attn_weights_float = utils.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + + with get_cuda_rng_tracker().fork(): + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [ + bsz * self.num_heads_partition, + tgt_len, + self.head_dim, + ] + embed_dim_partition = embed_dim // self.model_parallel_size + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition) + attn = self.out_proj(attn) + # return attn_weights None to keep the return type same as single gpu multihead attention + # This will be deprecated. + attn_weights: Optional[Tensor] = None + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + + filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1)) + if prev_key_padding_mask.is_cuda: + filler = filler.cuda() + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + elif key_padding_mask is not None: + filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1)) + if key_padding_mask.is_cuda: + filler = filler.cuda() + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def reorder_incremental_state( + self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + if input_buffer[k] is not None: + input_buffer[k] = input_buffer[k].index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) diff --git a/fairseq/model_parallel/modules/transformer_layer.py b/fairseq/model_parallel/modules/transformer_layer.py new file mode 100644 index 0000000..7ab53c6 --- /dev/null +++ b/fairseq/model_parallel/modules/transformer_layer.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.model_parallel.modules import ModelParallelMultiheadAttention +from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer + + +try: + from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer): + """Encoder layer block over multiple gpus. + + See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. + """ + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + if q_noise > 0: + raise NotImplementedError + return ColumnParallelLinear(input_dim, output_dim, gather_output=False) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + if q_noise > 0: + raise NotImplementedError + return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) + + def build_self_attention(self, embed_dim, args, **unused_kwargs): + return ModelParallelMultiheadAttention( + embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + ) + + +class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer): + """Decoder layer block. + + See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. + """ + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + if q_noise > 0: + raise NotImplementedError + return ColumnParallelLinear(input_dim, output_dim, gather_output=False) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + if q_noise > 0: + raise NotImplementedError + return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) + + def build_self_attention(self, embed_dim, args, **unused_kwargs): + return ModelParallelMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.decoder_attention_heads, + dropout=args.attention_dropout, + self_attention=not getattr(args, "cross_self_attention", False), + ) + + def build_encoder_attention(self, embed_dim, args, **unused_kwargs): + return ModelParallelMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + ) diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder.py b/fairseq/model_parallel/modules/transformer_sentence_encoder.py new file mode 100644 index 0000000..a5d50a3 --- /dev/null +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import random +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.model_parallel.modules import ModelParallelTransformerSentenceEncoderLayer +from fairseq.modules import ( + LayerNorm, + MultiheadAttention, + PositionalEmbedding, + TransformerSentenceEncoder, +) + + +try: + from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +class ModelParallelTransformerSentenceEncoder(TransformerSentenceEncoder): + """ + Implementation for a Model Parallel Bi-directional Transformer based + Sentence Encoder used in BERT/XLM style pre-trained models. + """ + + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) + + def build_transformer_sentence_encoder_layer( + self, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, + **unused, + ): + return ModelParallelTransformerSentenceEncoderLayer( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + ) diff --git a/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py new file mode 100644 index 0000000..e10bf52 --- /dev/null +++ b/fairseq/model_parallel/modules/transformer_sentence_encoder_layer.py @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.model_parallel.modules import ModelParallelMultiheadAttention +from fairseq.modules import TransformerSentenceEncoderLayer + + +try: + from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +class ModelParallelTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): + """ + Implements a Model Parallel Transformer Encoder Layer used in + BERT/XLM style pre-trained models. + """ + + def build_fc1(self, input_dim, output_dim, **unused): + return ColumnParallelLinear(input_dim, output_dim, gather_output=False) + + def build_fc2(self, input_dim, output_dim, **unused): + return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) + + def build_self_attention( + self, + embed_dim, + num_attention_heads, + dropout, + **kwargs, + ): + return ModelParallelMultiheadAttention( + embed_dim, num_attention_heads, dropout=dropout, self_attention=True + ) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout_module(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + return x, None diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py new file mode 100644 index 0000000..caed11e --- /dev/null +++ b/fairseq/models/__init__.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +import fairseq +from fairseq.dataclass import FairseqDataclass +from omegaconf import DictConfig, OmegaConf + +from .composite_encoder import CompositeEncoder +from .distributed_fairseq_model import DistributedFairseqModel +from .fairseq_decoder import FairseqDecoder +from .fairseq_encoder import FairseqEncoder +from .fairseq_incremental_decoder import FairseqIncrementalDecoder +from .fairseq_model import ( + BaseFairseqModel, + FairseqEncoderDecoderModel, + FairseqEncoderModel, + FairseqLanguageModel, + FairseqModel, + FairseqMultiModel, +) + + +MODEL_REGISTRY = {} +MODEL_DATACLASS_REGISTRY = {} +ARCH_MODEL_REGISTRY = {} +ARCH_MODEL_NAME_REGISTRY = {} +ARCH_MODEL_INV_REGISTRY = {} +ARCH_CONFIG_REGISTRY = {} + + +__all__ = [ + "BaseFairseqModel", + "DistributedFairseqModel", + "FairseqDecoder", + "FairseqEncoder", + "FairseqEncoderDecoderModel", + "FairseqEncoderModel", + "FairseqIncrementalDecoder", + "FairseqModel", +] + + +def build_model(cfg: DictConfig, task): + if isinstance(cfg, DictConfig): + if cfg._name == 'wav2vec_mtls1': + cfg._name = 'unispeech' + return ARCH_MODEL_REGISTRY[cfg._name].build_model(cfg, task) + if cfg.arch == 'wav2vec_mtls1': + cfg.arch = 'unispeech' + return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task) + + +def register_model(name, dataclass=None): + """ + New model types can be added to fairseq with the :func:`register_model` + function decorator. + + For example:: + + @register_model('lstm') + class LSTM(FairseqEncoderDecoderModel): + (...) + + .. note:: All models must implement the :class:`BaseFairseqModel` interface. + Typically you will extend :class:`FairseqEncoderDecoderModel` for + sequence-to-sequence tasks or :class:`FairseqLanguageModel` for + language modeling tasks. + + Args: + name (str): the name of the model + """ + + def register_model_cls(cls): + if name in MODEL_REGISTRY: + raise ValueError("Cannot register duplicate model ({})".format(name)) + if not issubclass(cls, BaseFairseqModel): + raise ValueError( + "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__) + ) + MODEL_REGISTRY[name] = cls + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if dataclass is not None: + MODEL_DATACLASS_REGISTRY[name] = dataclass + return cls + + return register_model_cls + + +def register_model_architecture(model_name, arch_name): + """ + New model architectures can be added to fairseq with the + :func:`register_model_architecture` function decorator. After registration, + model architectures can be selected with the ``--arch`` command-line + argument. + + For example:: + + @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') + def lstm_luong_wmt_en_de(cfg): + args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000) + (...) + + The decorated function should take a single argument *cfg*, which is a + :class:`omegaconf.DictConfig`. The decorated function should modify these + arguments in-place to match the desired architecture. + + Args: + model_name (str): the name of the Model (Model must already be + registered) + arch_name (str): the name of the model architecture (``--arch``) + """ + + def arch_override_from_yaml(args, arch): + root_dir = os.path.dirname(os.path.dirname(fairseq.__file__)) + yaml_path = os.path.join(root_dir, "config/model/{}.yaml".format(arch)) + if not os.path.exists(yaml_path): + raise RuntimeError(f"yaml file {yaml_path} does not exist!") + arch_cfg = OmegaConf.load(yaml_path) + for k, v in arch_cfg.items(): + setattr(args, k, getattr(args, k, v)) + + def register_model_arch_fn(fn): + if model_name not in MODEL_REGISTRY: + raise ValueError( + "Cannot register model architecture for unknown model type ({})".format( + model_name + ) + ) + if arch_name in ARCH_MODEL_REGISTRY: + raise ValueError( + "Cannot register duplicate model architecture ({})".format(arch_name) + ) + if not callable(fn): + raise ValueError( + "Model architecture must be callable ({})".format(arch_name) + ) + ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] + ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name + ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) + if type(fn) is type and issubclass(fn, BaseFairseqModel): + # for model classes migrated with hydra + # in this case, we are using this decorator directly on model class since + # we do not need arch overriding functions. + ARCH_CONFIG_REGISTRY[arch_name] = lambda args: arch_override_from_yaml( + args, arch=arch_name + ) + else: + ARCH_CONFIG_REGISTRY[arch_name] = fn + return fn + + return register_model_arch_fn + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.models." + model_name) + + # extra `model_parser` for sphinx + if model_name in MODEL_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_archs = parser.add_argument_group("Named architectures") + group_archs.add_argument( + "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] + ) + group_args = parser.add_argument_group("Additional command-line arguments") + MODEL_REGISTRY[model_name].add_args(group_args) + globals()[model_name + "_parser"] = parser diff --git a/fairseq/models/composite_encoder.py b/fairseq/models/composite_encoder.py new file mode 100644 index 0000000..4e20fe3 --- /dev/null +++ b/fairseq/models/composite_encoder.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .fairseq_encoder import FairseqEncoder + + +class CompositeEncoder(FairseqEncoder): + """ + A wrapper around a dictionary of :class:`FairseqEncoder` objects. + + We run forward on each encoder and return a dictionary of outputs. The first + encoder's dictionary is used for initialization. + + Args: + encoders (dict): a dictionary of :class:`FairseqEncoder` objects. + """ + + def __init__(self, encoders): + super().__init__(next(iter(encoders.values())).dictionary) + self.encoders = encoders + for key in self.encoders: + self.add_module(key, self.encoders[key]) + + def forward(self, src_tokens, src_lengths): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + + Returns: + dict: + the outputs from each Encoder + """ + encoder_out = {} + for key in self.encoders: + encoder_out[key] = self.encoders[key](src_tokens, src_lengths) + return encoder_out + + def reorder_encoder_out(self, encoder_out, new_order): + """Reorder encoder output according to new_order.""" + for key in self.encoders: + encoder_out[key] = self.encoders[key].reorder_encoder_out( + encoder_out[key], new_order + ) + return encoder_out + + def max_positions(self): + return min(self.encoders[key].max_positions() for key in self.encoders) + + def upgrade_state_dict(self, state_dict): + for key in self.encoders: + self.encoders[key].upgrade_state_dict(state_dict) + return state_dict diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py new file mode 100644 index 0000000..ece10c6 --- /dev/null +++ b/fairseq/models/distributed_fairseq_model.py @@ -0,0 +1,103 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import inspect + +import torch.nn as nn +from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel + + +_GOSSIP_DISABLED = False +try: + import gossip +except ImportError: + _GOSSIP_DISABLED = True + + +def DistributedFairseqModel(args, model, process_group=None): + """ + Wrap a *model* to support distributed data parallel training. + + This is similar to the built-in DistributedDataParallel, but allows + additional configuration of the DistributedDataParallel class to + use, and also provides easier access to the wrapped model by + forwarding requests for missing attributes to the wrapped model. + + Args: + args (argparse.Namespace): fairseq args + model (BaseFairseqModel): model to wrap + """ + # determine which DDP class to extend + assert isinstance(model, nn.Module) + if args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d": + ddp_class = nn.parallel.DistributedDataParallel + init_kwargs = dict( + module=model, + device_ids=[args.device_id], + output_device=args.device_id, + broadcast_buffers=args.broadcast_buffers, + bucket_cap_mb=args.bucket_cap_mb, + process_group=process_group, + ) + # Maintain backward compatibility + if "check_reduction" in inspect.getargspec(ddp_class)[0]: + init_kwargs["check_reduction"] = True + if "find_unused_parameters" in inspect.getargspec(ddp_class)[0]: + init_kwargs["find_unused_parameters"] = args.find_unused_parameters + elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d": + ddp_class = LegacyDistributedDataParallel + init_kwargs = dict( + module=model, + world_size=args.distributed_world_size, + buffer_size=2 ** 28, + process_group=process_group, + ) + elif args.distributed_wrapper == "SlowMo": + if _GOSSIP_DISABLED: + raise ImportError( + "Cannot find gossip library. Please install from: " + "github.com/facebookresearch/stochastic_gradient_push" + ) + ddp_class = gossip.GossipDataParallel + + # The values of slowmo_momentum below were obtained by tuning on the + # En-De 16 dataset by training the transformer_wmt_en_de_large model + if args.slowmo_momentum is None: + if args.distributed_world_size <= 16: + args.slowmo_momentum = 0.0 + elif args.distributed_world_size <= 32: + args.slowmo_momentum = 0.2 + elif args.distributed_world_size <= 64: + args.slowmo_momentum = 0.5 + else: + args.slowmo_momentum = 0.6 + + init_kwargs = dict( + module=model, + device_ids=[args.device_id], + output_device=args.device_id, + broadcast_buffers=args.broadcast_buffers, + nprocs_per_node=args.nprocs_per_node, + slowmo_momentum=args.slowmo_momentum, + localsgd=(args.slowmo_algorithm == "LocalSGD"), + localsgd_frequency=args.localsgd_frequency, + ) + else: + raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) + + class _DistributedFairseqModel(ddp_class): + """Extend DistributedDataParallel to check for missing + attributes in the wrapped module.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __getattr__(self, name): + wrapped_module = super().__getattr__("module") + if hasattr(wrapped_module, name): + return getattr(wrapped_module, name) + return super().__getattr__(name) + + return _DistributedFairseqModel(**init_kwargs) diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py new file mode 100644 index 0000000..fb6c52d --- /dev/null +++ b/fairseq/models/fairseq_decoder.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Tuple + +import torch.nn as nn +from fairseq import utils +from torch import Tensor + + +class FairseqDecoder(nn.Module): + """Base class for decoders.""" + + def __init__(self, dictionary): + super().__init__() + self.dictionary = dictionary + self.onnx_trace = False + + def forward(self, prev_output_tokens, encoder_out=None, **kwargs): + """ + Args: + prev_output_tokens (LongTensor): shifted output tokens of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (dict, optional): output from the encoder, used for + encoder-side attention + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + x, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + x = self.output_layer(x) + return x, extra + + def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs): + """ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + raise NotImplementedError + + def output_layer(self, features, **kwargs): + """ + Project features to the default output size, e.g., vocabulary size. + + Args: + features (Tensor): features returned by *extract_features*. + """ + raise NotImplementedError + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: + if sample is not None: + assert "target" in sample + target = sample["target"] + else: + target = None + out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) + return out.exp_() if not log_probs else out + + logits = net_output[0] + if log_probs: + return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) + else: + return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace) + + def max_positions(self): + """Maximum input length supported by the decoder.""" + return 1e6 # an arbitrary large number + + def upgrade_state_dict(self, state_dict): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + def prepare_for_onnx_export_(self): + self.onnx_trace = True diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py new file mode 100644 index 0000000..c8873da --- /dev/null +++ b/fairseq/models/fairseq_encoder.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, NamedTuple, Optional + +import torch +import torch.nn as nn +from torch import Tensor + + +EncoderOut = NamedTuple( + "EncoderOut", + [ + ("encoder_out", Tensor), # T x B x C + ("encoder_padding_mask", Optional[Tensor]), # B x T + ("encoder_embedding", Optional[Tensor]), # B x T x C + ("encoder_states", Optional[List[Tensor]]), # List[T x B x C] + ("src_tokens", Optional[Tensor]), # B x T + ("src_lengths", Optional[Tensor]), # B x 1 + ], +) + + +class FairseqEncoder(nn.Module): + """Base class for encoders.""" + + def __init__(self, dictionary): + super().__init__() + self.dictionary = dictionary + + def forward(self, src_tokens, src_lengths=None, **kwargs): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + """ + raise NotImplementedError + + def forward_torchscript(self, net_input: Dict[str, Tensor]): + """A TorchScript-compatible version of forward. + + Encoders which use additional arguments may want to override + this method for TorchScript compatibility. + """ + if torch.jit.is_scripting(): + return self.forward( + src_tokens=net_input["src_tokens"], + src_lengths=net_input["src_lengths"], + ) + else: + return self.forward_non_torchscript(net_input) + + @torch.jit.unused + def forward_non_torchscript(self, net_input: Dict[str, Tensor]): + encoder_input = { + k: v for k, v in net_input.items() if k != "prev_output_tokens" + } + return self.forward(**encoder_input) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to `new_order`. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + `encoder_out` rearranged according to `new_order` + """ + raise NotImplementedError + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return 1e6 # an arbitrary large number + + def upgrade_state_dict(self, state_dict): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + + def _apply(m): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + self.apply(_apply) diff --git a/fairseq/models/fairseq_incremental_decoder.py b/fairseq/models/fairseq_incremental_decoder.py new file mode 100644 index 0000000..cc72a0f --- /dev/null +++ b/fairseq/models/fairseq_incremental_decoder.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, Optional + +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.models import FairseqDecoder +from torch import Tensor + + +logger = logging.getLogger(__name__) + + +@with_incremental_state +class FairseqIncrementalDecoder(FairseqDecoder): + """Base class for incremental decoders. + + Incremental decoding is a special mode at inference time where the Model + only receives a single timestep of input corresponding to the previous + output token (for teacher forcing) and must produce the next output + *incrementally*. Thus the model must cache any long-term state that is + needed about the sequence, e.g., hidden states, convolutional states, etc. + + Compared to the standard :class:`FairseqDecoder` interface, the incremental + decoder interface allows :func:`forward` functions to take an extra keyword + argument (*incremental_state*) that can be used to cache state across + time-steps. + + The :class:`FairseqIncrementalDecoder` interface also defines the + :func:`reorder_incremental_state` method, which is used during beam search + to select and reorder the incremental state based on the selection of beams. + + To learn more about how incremental decoding works, refer to `this blog + `_. + """ + + def __init__(self, dictionary): + super().__init__(dictionary) + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): + """ + Args: + prev_output_tokens (LongTensor): shifted output tokens of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (dict, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict, optional): dictionary used for storing + state during :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + raise NotImplementedError + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): + """ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + raise NotImplementedError + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder incremental state. + + This will be called when the order of the input has changed from the + previous time step. A typical use case is beam search, where the input + order changes between time steps based on the selection of beams. + """ + pass + + def reorder_incremental_state_scripting( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Main entry point for reordering the incremental state. + + Due to limitations in TorchScript, we call this function in + :class:`fairseq.sequence_generator.SequenceGenerator` instead of + calling :func:`reorder_incremental_state` directly. + """ + for module in self.modules(): + if hasattr(module, "reorder_incremental_state"): + result = module.reorder_incremental_state(incremental_state, new_order) + if result is not None: + incremental_state = result + + def set_beam_size(self, beam_size): + """Sets the beam size in the decoder and all children.""" + if getattr(self, "_beam_size", -1) != beam_size: + seen = set() + + def apply_set_beam_size(module): + if ( + module != self + and hasattr(module, "set_beam_size") + and module not in seen + ): + seen.add(module) + module.set_beam_size(beam_size) + + self.apply(apply_set_beam_size) + self._beam_size = beam_size diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py new file mode 100644 index 0000000..0c8d106 --- /dev/null +++ b/fairseq/models/fairseq_model.py @@ -0,0 +1,573 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base classes for various fairseq models. +""" + +import logging +from argparse import Namespace +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.checkpoint_utils import prune_state_dict +from fairseq.data import Dictionary +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + gen_parser_from_dataclass, +) +from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig +from torch import Tensor + + +logger = logging.getLogger(__name__) + + +class BaseFairseqModel(nn.Module): + """Base class for fairseq models.""" + + def __init__(self): + super().__init__() + self._is_generation_fast = False + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + # do not set defaults so that settings defaults from various architectures still works + gen_parser_from_dataclass(parser, dc(), delete_default=True) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + raise NotImplementedError("Model must implement the build_model method") + + def get_targets(self, sample, net_output): + """Get targets from either the sample or the net's output.""" + return sample["target"] + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel""" + if hasattr(self, "decoder"): + return self.decoder.get_normalized_probs(net_output, log_probs, sample) + elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) + logits = net_output.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def extract_features(self, *args, **kwargs): + """Similar to *forward* but only return features.""" + return self(*args, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return None + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + + if model_cfg is None and args is not None: + logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + model_cfg = convert_namespace_to_omegaconf(args).model + + self.upgrade_state_dict(state_dict) + new_state_dict = prune_state_dict(state_dict, model_cfg) + return super().load_state_dict(new_state_dict, strict) + + def upgrade_state_dict(self, state_dict): + """Upgrade old state dicts to work with newer code.""" + self.upgrade_state_dict_named(state_dict, "") + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code. + + Args: + state_dict (dict): state dictionary to upgrade, in place + name (str): the state dict key corresponding to the current module + """ + assert state_dict is not None + + def do_upgrade(m, prefix): + if len(prefix) > 0: + prefix += "." + + for n, c in m.named_children(): + name = prefix + n + if hasattr(c, "upgrade_state_dict_named"): + c.upgrade_state_dict_named(state_dict, name) + elif hasattr(c, "upgrade_state_dict"): + c.upgrade_state_dict(state_dict) + do_upgrade(c, name) + + do_upgrade(self, name) + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + + def _apply(m): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + self.apply(_apply) + + def prepare_for_inference_(self, cfg: DictConfig): + """Prepare model for inference.""" + kwargs = {} + kwargs["beamable_mm_beam_size"] = ( + None + if getattr(cfg.generation, "no_beamable_mm", False) + else getattr(cfg.generation, "beam", 5) + ) + kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False) + if getattr(cfg.generation, "retain_dropout", False): + kwargs["retain_dropout"] = cfg.generation.retain_dropout + kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules + self.make_generation_fast_(**kwargs) + + def make_generation_fast_(self, **kwargs): + """ + Legacy entry point to optimize model for faster generation. + Prefer prepare_for_inference_. + """ + if self._is_generation_fast: + return # only apply once + self._is_generation_fast = True + + # remove weight norm from all modules in the network + def apply_remove_weight_norm(module): + try: + nn.utils.remove_weight_norm(module) + except ValueError: # this module didn't have weight norm + return + + self.apply(apply_remove_weight_norm) + + def apply_make_generation_fast_(module, prefix): + if len(prefix) > 0: + prefix += "." + + base_func = BaseFairseqModel.make_generation_fast_ + for n, m in module.named_modules(): + if ( + m != self + and hasattr(m, "make_generation_fast_") + # don't call this implementation again, e.g., if + # children modules also inherit from BaseFairseqModel + and m.make_generation_fast_.__func__ is not base_func + ): + name = prefix + n + m.make_generation_fast_(name=name, **kwargs) + + apply_make_generation_fast_(self, "") + + def train(mode=True): + if mode: + raise RuntimeError("cannot train after make_generation_fast") + + # this model should no longer be used for training + self.eval() + self.train = train + + def prepare_for_onnx_export_(self, **kwargs): + """Make model exportable via ONNX trace.""" + seen = set() + + def apply_prepare_for_onnx_export_(module): + if ( + module != self + and hasattr(module, "prepare_for_onnx_export_") + and module not in seen + ): + seen.add(module) + module.prepare_for_onnx_export_(**kwargs) + + self.apply(apply_prepare_for_onnx_export_) + + def prepare_for_tpu_(self, **kwargs): + """Optionally modify model for use on TPUs.""" + seen = set() + + def apply_prepare_for_tpu_(module): + if ( + module != self + and hasattr(module, "prepare_for_tpu_") + and module not in seen + ): + seen.add(module) + module.prepare_for_tpu_(**kwargs) + + self.apply(apply_prepare_for_tpu_) + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + **kwargs, + ): + """ + Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model + file. Downloads and caches the pre-trained model file if needed. + + The base implementation returns a + :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to + generate translations or sample from language models. The underlying + :class:`~fairseq.models.FairseqModel` can be accessed via the + *generator.models* attribute. + + Other models may override this to implement custom hub interfaces. + + Args: + model_name_or_path (str): either the name of a pre-trained model to + load or a path/URL to a pre-trained model state dict + checkpoint_file (str, optional): colon-separated list of checkpoint + files in the model archive to ensemble (default: 'model.pt') + data_name_or_path (str, optional): point args.data to the archive + at the given path/URL. Can start with '.' or './' to reuse the + model archive path. + """ + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + **kwargs, + ) + logger.info(x["args"]) + return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"]) + + @classmethod + def hub_models(cls): + return {} + + +class FairseqEncoderDecoderModel(BaseFairseqModel): + """Base class for encoder-decoder models. + + Args: + encoder (FairseqEncoder): the encoder + decoder (FairseqDecoder): the decoder + """ + + def __init__(self, encoder, decoder): + super().__init__() + + self.encoder = encoder + self.decoder = decoder + assert isinstance(self.encoder, FairseqEncoder) + assert isinstance(self.decoder, FairseqDecoder) + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + Run the forward pass for an encoder-decoder model. + + First feed a batch of source tokens through the encoder. Then, feed the + encoder output and previous decoder outputs (i.e., teacher forcing) to + the decoder to produce the next outputs:: + + encoder_out = self.encoder(src_tokens, src_lengths) + return self.decoder(prev_output_tokens, encoder_out) + + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + decoder_out = self.decoder( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return decoder_out + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + features = self.decoder.extract_features( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return features + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder.max_positions(), self.decoder.max_positions()) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + +class FairseqModel(FairseqEncoderDecoderModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + utils.deprecation_warning( + "FairseqModel is deprecated, please use FairseqEncoderDecoderModel " + "or BaseFairseqModel instead", + stacklevel=4, + ) + + +class FairseqMultiModel(BaseFairseqModel): + """Base class for combining multiple encoder-decoder models.""" + + def __init__(self, encoders, decoders): + super().__init__() + assert encoders.keys() == decoders.keys() + self.keys = list(encoders.keys()) + for key in self.keys: + assert isinstance(encoders[key], FairseqEncoder) + assert isinstance(decoders[key], FairseqDecoder) + + self.models = nn.ModuleDict( + { + key: FairseqEncoderDecoderModel(encoders[key], decoders[key]) + for key in self.keys + } + ) + + @staticmethod + def build_shared_embeddings( + dicts: Dict[str, Dictionary], + langs: List[str], + embed_dim: int, + build_embedding: callable, + pretrained_embed_path: Optional[str] = None, + ): + """ + Helper function to build shared embeddings for a set of languages after + checking that all dicts corresponding to those languages are equivalent. + + Args: + dicts: Dict of lang_id to its corresponding Dictionary + langs: languages that we want to share embeddings for + embed_dim: embedding dimension + build_embedding: callable function to actually build the embedding + pretrained_embed_path: Optional path to load pretrained embeddings + """ + shared_dict = dicts[langs[0]] + if any(dicts[lang] != shared_dict for lang in langs): + raise ValueError( + "--share-*-embeddings requires a joined dictionary: " + "--share-encoder-embeddings requires a joined source " + "dictionary, --share-decoder-embeddings requires a joined " + "target dictionary, and --share-all-embeddings requires a " + "joint source + target dictionary." + ) + return build_embedding(shared_dict, embed_dim, pretrained_embed_path) + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + raise NotImplementedError + + def max_positions(self): + """Maximum length supported by the model.""" + return { + key: ( + self.models[key].encoder.max_positions(), + self.models[key].decoder.max_positions(), + ) + for key in self.keys + } + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return min(model.decoder.max_positions() for model in self.models.values()) + + @property + def encoder(self): + return self.models[self.keys[0]].encoder + + @property + def decoder(self): + return self.models[self.keys[0]].decoder + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args: Optional[Namespace] = None, + ): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + + if model_cfg is None and args is not None: + logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + model_cfg = convert_namespace_to_omegaconf(args).model + + self.upgrade_state_dict(state_dict) + new_state_dict = prune_state_dict(state_dict, model_cfg) + return super().load_state_dict(new_state_dict, strict) + + +class FairseqLanguageModel(BaseFairseqModel): + """Base class for decoder-only models. + + Args: + decoder (FairseqDecoder): the decoder + """ + + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + assert isinstance(self.decoder, FairseqDecoder) + + def forward(self, src_tokens, **kwargs): + """ + Run the forward pass for a decoder-only model. + + Feeds a batch of tokens through the decoder to predict the next tokens. + + Args: + src_tokens (LongTensor): tokens on which to condition the decoder, + of shape `(batch, tgt_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + + Returns: + tuple: + - the decoder's output of shape `(batch, seq_len, vocab)` + - a dictionary with any model-specific outputs + """ + return self.decoder(src_tokens, **kwargs) + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def extract_features(self, src_tokens, **kwargs): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, seq_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + return self.decoder.extract_features(src_tokens, **kwargs) + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return self.decoder.max_positions() + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + @property + def supported_targets(self): + return {"future"} + + +class FairseqEncoderModel(BaseFairseqModel): + """Base class for encoder-only models. + + Args: + encoder (FairseqEncoder): the encoder + """ + + def __init__(self, encoder): + super().__init__() + self.encoder = encoder + assert isinstance(self.encoder, FairseqEncoder) + + def forward(self, src_tokens, src_lengths, **kwargs): + """ + Run the forward pass for a encoder-only model. + + Feeds a batch of tokens through the encoder to generate features. + + Args: + src_tokens (LongTensor): input tokens of shape `(batch, src_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + + Returns: + the encoder's output, typically of shape `(batch, src_len, features)` + """ + return self.encoder(src_tokens, src_lengths, **kwargs) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + encoder_out = net_output["encoder_out"] + if torch.is_tensor(encoder_out): + logits = encoder_out.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def max_positions(self): + """Maximum length supported by the model.""" + return self.encoder.max_positions() diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py new file mode 100644 index 0000000..7614c33 --- /dev/null +++ b/fairseq/models/transformer.py @@ -0,0 +1,1036 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.modules import ( + AdaptiveSoftmax, + FairseqDropout, + LayerDropModuleList, + LayerNorm, + PositionalEmbedding, + SinusoidalPositionalEmbedding, + TransformerDecoderLayer, + TransformerEncoderLayer, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ +from torch import Tensor + + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@register_model("transformer") +class TransformerModel(FairseqEncoderDecoderModel): + """ + Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) + `_. + + Args: + encoder (TransformerEncoder): the encoder + decoder (TransformerDecoder): the decoder + + The Transformer model provides the following named architectures and + command-line arguments: + + .. argparse:: + :ref: fairseq.models.transformer_parser + :prog: + """ + + @classmethod + def hub_models(cls): + # fmt: off + + def moses_subword(path): + return { + 'path': path, + 'tokenizer': 'moses', + 'bpe': 'subword_nmt', + } + + def moses_fastbpe(path): + return { + 'path': path, + 'tokenizer': 'moses', + 'bpe': 'fastbpe', + } + + return { + 'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'), + 'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2', + 'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'), + 'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'), + 'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'), + 'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'), + 'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'), + 'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'), + 'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'), + 'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'), + 'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'), + } + # fmt: on + + def __init__(self, args, encoder, decoder): + super().__init__(encoder, decoder) + self.args = args + self.supports_align_args = True + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--activation-fn', + choices=utils.get_available_activation_fns(), + help='activation function to use') + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--attention-dropout', type=float, metavar='D', + help='dropout probability for attention weights') + parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', + help='dropout probability after activation in FFN.') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', + help='encoder embedding dimension for FFN') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='num encoder layers') + parser.add_argument('--encoder-attention-heads', type=int, metavar='N', + help='num encoder attention heads') + parser.add_argument('--encoder-normalize-before', action='store_true', + help='apply layernorm before each encoder block') + parser.add_argument('--encoder-learned-pos', action='store_true', + help='use learned positional embeddings in the encoder') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', + help='decoder embedding dimension for FFN') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='num decoder layers') + parser.add_argument('--decoder-attention-heads', type=int, metavar='N', + help='num decoder attention heads') + parser.add_argument('--decoder-learned-pos', action='store_true', + help='use learned positional embeddings in the decoder') + parser.add_argument('--decoder-normalize-before', action='store_true', + help='apply layernorm before each decoder block') + parser.add_argument('--decoder-output-dim', type=int, metavar='N', + help='decoder output dimension (extra linear layer ' + 'if different from decoder embed dim') + parser.add_argument('--share-decoder-input-output-embed', action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', + help='if set, disables positional embeddings (outside self attention)') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion'), + parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', + help='sets adaptive softmax dropout for the tail projections') + parser.add_argument('--layernorm-embedding', action='store_true', + help='add layernorm to embedding') + parser.add_argument('--no-scale-embedding', action='store_true', + help='if True, dont scale embeddings') + parser.add_argument('--checkpoint-activations', action='store_true', + help='checkpoint activations at each layer, which saves GPU ' + 'memory usage at the cost of some additional compute') + # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) + parser.add_argument('--no-cross-attention', default=False, action='store_true', + help='do not perform cross-attention') + parser.add_argument('--cross-self-attention', default=False, action='store_true', + help='perform cross+self-attention') + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0, + help='LayerDrop probability for encoder') + parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0, + help='LayerDrop probability for decoder') + parser.add_argument('--encoder-layers-to-keep', default=None, + help='which layers to *keep* when pruning as a comma-separated list') + parser.add_argument('--decoder-layers-to-keep', default=None, + help='which layers to *keep* when pruning as a comma-separated list') + # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) + parser.add_argument('--quant-noise-pq', type=float, metavar='D', default=0, + help='iterative PQ quantization noise at training time') + parser.add_argument('--quant-noise-pq-block-size', type=int, metavar='D', default=8, + help='block size of quantization noise at training time') + parser.add_argument('--quant-noise-scalar', type=float, metavar='D', default=0, + help='scalar quantization noise and scalar quantization at training time') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_source_positions", None) is None: + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError("--share-all-embeddings requires a joined dictionary") + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = cls.build_embedding( + args, src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + encoder_embed_tokens = cls.build_embedding( + args, src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = cls.build_embedding( + args, tgt_dict, args.decoder_embed_dim, args.decoder_embed_path + ) + + encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) + decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + return cls(args, encoder, decoder) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return TransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoder( + args, + tgt_dict, + embed_tokens, + no_encoder_attn=getattr(args, "no_cross_attention", False), + ) + + # TorchScript doesn't support optional arguments with variable length (**kwargs). + # Current workaround is to add union of all arguments in child classes. + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + return_all_hiddens: bool = True, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Run the forward pass for an encoder-decoder model. + + Copied from the base class, but without ``**kwargs``, + which are not supported by TorchScript. + """ + encoder_out = self.encoder( + src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens + ) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + features_only=features_only, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, + ) + return decoder_out + + # Since get_normalized_probs is in the Fairseq Model which is not scriptable, + # I rewrite the get_normalized_probs from Base Class to call the + # helper function in the Base Class. + @torch.jit.export + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + +class TransformerEncoder(FairseqEncoder): + """ + Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, dictionary, embed_tokens): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.encoder_layerdrop = args.encoder_layerdrop + + embed_dim = embed_tokens.embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = args.max_source_positions + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + + self.embed_positions = ( + PositionalEmbedding( + args.max_source_positions, + embed_dim, + self.padding_idx, + learned=args.encoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + if not args.adaptive_input and args.quant_noise_pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + self.quant_noise = None + + if self.encoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.encoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [self.build_encoder_layer(args) for i in range(args.encoder_layers)] + ) + self.num_layers = len(self.layers) + + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def build_encoder_layer(self, args): + layer = TransformerEncoderLayer(args) + if getattr(args, "checkpoint_activations", False): + layer = checkpoint_wrapper(layer) + return layer + + def forward_embedding( + self, src_tokens, token_embedding: Optional[torch.Tensor] = None + ): + # embed tokens and positions + if token_embedding is None: + token_embedding = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * token_embedding + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + if self.quant_noise is not None: + x = self.quant_noise(x) + return x, embed + + def forward( + self, + src_tokens, + src_lengths, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + namedtuple: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + + encoder_states = [] if return_all_hiddens else None + + # encoder layers + for layer in self.layers: + x = layer(x, encoder_padding_mask) + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return EncoderOut( + encoder_out=x, # T x B x C + encoder_padding_mask=encoder_padding_mask, # B x T + encoder_embedding=encoder_embedding, # B x T x C + encoder_states=encoder_states, # List[T x B x C] + src_tokens=None, + src_lengths=None, + ) + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + """ + Since encoder_padding_mask and encoder_embedding are both of type + Optional[Tensor] in EncoderOut, they need to be copied as local + variables for Torchscript Optional refinement + """ + encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask + encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding + + new_encoder_out = ( + encoder_out.encoder_out + if encoder_out.encoder_out is None + else encoder_out.encoder_out.index_select(1, new_order) + ) + new_encoder_padding_mask = ( + encoder_padding_mask + if encoder_padding_mask is None + else encoder_padding_mask.index_select(0, new_order) + ) + new_encoder_embedding = ( + encoder_embedding + if encoder_embedding is None + else encoder_embedding.index_select(0, new_order) + ) + src_tokens = encoder_out.src_tokens + if src_tokens is not None: + src_tokens = src_tokens.index_select(0, new_order) + + src_lengths = encoder_out.src_lengths + if src_lengths is not None: + src_lengths = src_lengths.index_select(0, new_order) + + encoder_states = encoder_out.encoder_states + if encoder_states is not None: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return EncoderOut( + encoder_out=new_encoder_out, # T x B x C + encoder_padding_mask=new_encoder_padding_mask, # B x T + encoder_embedding=new_encoder_embedding, # B x T x C + encoder_states=encoder_states, # List[T x B x C] + src_tokens=src_tokens, # B x T + src_lengths=src_lengths, # B x 1 + ) + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embed_positions is None: + return self.max_source_positions + return min(self.max_source_positions, self.embed_positions.max_positions) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + weights_key = "{}.embed_positions.weights".format(name) + if weights_key in state_dict: + print("deleting {0}".format(weights_key)) + del state_dict[weights_key] + state_dict[ + "{}.embed_positions._float_tensor".format(name) + ] = torch.FloatTensor(1) + for i in range(self.num_layers): + # update layer norms + self.layers[i].upgrade_state_dict_named( + state_dict, "{}.layers.{}".format(name, i) + ) + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + return state_dict + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + self.args = args + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + self._future_mask = torch.empty(0) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.decoder_layerdrop = args.decoder_layerdrop + self.share_input_output_embed = args.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + self.embed_dim = embed_dim + self.output_embed_dim = args.decoder_output_dim + + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + + if not args.adaptive_input and args.quant_noise_pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + self.quant_noise = None + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + self.embed_positions = ( + PositionalEmbedding( + self.max_target_positions, + embed_dim, + self.padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + self.cross_self_attention = getattr(args, "cross_self_attention", False) + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_decoder_layer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) + self.num_layers = len(self.layers) + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + self.project_out_dim = ( + Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights + else None + ) + + self.adaptive_softmax = None + self.output_projection = None + if args.adaptive_softmax_cutoff is not None: + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + self.output_embed_dim, + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif self.share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + self.output_projection = nn.Linear( + self.output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5 + ) + + def build_decoder_layer(self, args, no_encoder_attn=False): + layer = TransformerDecoderLayer(args, no_encoder_attn) + if getattr(args, "checkpoint_activations", False): + layer = checkpoint_wrapper(layer) + return layer + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + if not features_only: + x = self.output_layer(x) + return x, extra + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + return self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + + """ + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. Aa copy of + this function is made to be used in the subclass instead. + """ + + def extract_features_scriptable( + self, + prev_output_tokens, + encoder_out: Optional[EncoderOut] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = layer( + x, + encoder_out.encoder_out if encoder_out is not None else None, + encoder_out.encoder_padding_mask if encoder_out is not None else None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "inner_states": inner_states} + + def output_layer(self, features): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + return self.output_projection(features) + else: + return features + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. + if ( + self._future_mask.size(0) == 0 + or (not self._future_mask.device == tensor.device) + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 + ) + self._future_mask = self._future_mask.to(tensor) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + weights_key = "{}.embed_positions.weights".format(name) + if weights_key in state_dict: + del state_dict[weights_key] + state_dict[ + "{}.embed_positions._float_tensor".format(name) + ] = torch.FloatTensor(1) + + if f"{name}.output_projection.weight" not in state_dict: + if self.share_input_output_embed: + embed_out_key = f"{name}.embed_tokens.weight" + else: + embed_out_key = f"{name}.embed_out" + if embed_out_key in state_dict: + state_dict[f"{name}.output_projection.weight"] = state_dict[ + embed_out_key + ] + if not self.share_input_output_embed: + del state_dict[embed_out_key] + + for i in range(self.num_layers): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +@register_model_architecture("transformer", "transformer") +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.no_cross_attention = getattr(args, "no_cross_attention", False) + args.cross_self_attention = getattr(args, "cross_self_attention", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + +@register_model_architecture("transformer", "transformer_iwslt_de_en") +def transformer_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) + base_architecture(args) + + +@register_model_architecture("transformer", "transformer_wmt_en_de") +def transformer_wmt_en_de(args): + base_architecture(args) + + +# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) +@register_model_architecture("transformer", "transformer_vaswani_wmt_en_de_big") +def transformer_vaswani_wmt_en_de_big(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + base_architecture(args) + + +@register_model_architecture("transformer", "transformer_vaswani_wmt_en_fr_big") +def transformer_vaswani_wmt_en_fr_big(args): + args.dropout = getattr(args, "dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) + + +@register_model_architecture("transformer", "transformer_wmt_en_de_big") +def transformer_wmt_en_de_big(args): + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) + + +# default parameters used in tensor2tensor implementation +@register_model_architecture("transformer", "transformer_wmt_en_de_big_t2t") +def transformer_wmt_en_de_big_t2t(args): + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) diff --git a/fairseq/models/transformer_align.py b/fairseq/models/transformer_align.py new file mode 100644 index 0000000..eaf585b --- /dev/null +++ b/fairseq/models/transformer_align.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer import ( + TransformerModel, + base_architecture, + transformer_wmt_en_de_big, +) + + +@register_model("transformer_align") +class TransformerAlignModel(TransformerModel): + """ + See "Jointly Learning to Align and Translate with Transformer + Models" (Garg et al., EMNLP 2019). + """ + + def __init__(self, encoder, decoder, args): + super().__init__(args, encoder, decoder) + self.alignment_heads = args.alignment_heads + self.alignment_layer = args.alignment_layer + self.full_context_alignment = args.full_context_alignment + + @staticmethod + def add_args(parser): + # fmt: off + super(TransformerAlignModel, TransformerAlignModel).add_args(parser) + parser.add_argument('--alignment-heads', type=int, metavar='D', + help='Number of cross attention heads per layer to supervised with alignments') + parser.add_argument('--alignment-layer', type=int, metavar='D', + help='Layer number which has to be supervised. 0 corresponding to the bottommost layer.') + parser.add_argument('--full-context-alignment', action='store_true', + help='Whether or not alignment is supervised conditioned on the full target context.') + # fmt: on + + @classmethod + def build_model(cls, args, task): + # set any default arguments + transformer_align(args) + + transformer_model = TransformerModel.build_model(args, task) + return TransformerAlignModel( + transformer_model.encoder, transformer_model.decoder, args + ) + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_out = self.encoder(src_tokens, src_lengths) + return self.forward_decoder(prev_output_tokens, encoder_out) + + def forward_decoder( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + features_only=False, + **extra_args, + ): + attn_args = { + "alignment_layer": self.alignment_layer, + "alignment_heads": self.alignment_heads, + } + decoder_out = self.decoder(prev_output_tokens, encoder_out, **attn_args) + + if self.full_context_alignment: + attn_args["full_context_alignment"] = self.full_context_alignment + _, alignment_out = self.decoder( + prev_output_tokens, + encoder_out, + features_only=True, + **attn_args, + **extra_args, + ) + decoder_out[1]["attn"] = alignment_out["attn"] + + return decoder_out + + +@register_model_architecture("transformer_align", "transformer_align") +def transformer_align(args): + args.alignment_heads = getattr(args, "alignment_heads", 1) + args.alignment_layer = getattr(args, "alignment_layer", 4) + args.full_context_alignment = getattr(args, "full_context_alignment", False) + base_architecture(args) + + +@register_model_architecture("transformer_align", "transformer_wmt_en_de_big_align") +def transformer_wmt_en_de_big_align(args): + args.alignment_heads = getattr(args, "alignment_heads", 1) + args.alignment_layer = getattr(args, "alignment_layer", 4) + transformer_wmt_en_de_big(args) diff --git a/fairseq/models/transformer_from_pretrained_xlm.py b/fairseq/models/transformer_from_pretrained_xlm.py new file mode 100644 index 0000000..236d994 --- /dev/null +++ b/fairseq/models/transformer_from_pretrained_xlm.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Any, Dict + +from fairseq import checkpoint_utils +from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer import ( + TransformerDecoder, + TransformerEncoder, + TransformerModel, + base_architecture as transformer_base_architecture, +) + + +@register_model("transformer_from_pretrained_xlm") +class TransformerFromPretrainedXLMModel(TransformerModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + TransformerModel.add_args(parser) + parser.add_argument( + "--pretrained-xlm-checkpoint", + type=str, + metavar="STR", + help="XLM model to use for initializing transformer encoder and/or decoder", + ) + parser.add_argument( + "--init-encoder-only", + action="store_true", + help="if set, don't load the XLM weights and embeddings into decoder", + ) + parser.add_argument( + "--init-decoder-only", + action="store_true", + help="if set, don't load the XLM weights and embeddings into encoder", + ) + + @classmethod + def build_model(self, args, task, cls_dictionary=MaskedLMDictionary): + assert hasattr(args, "pretrained_xlm_checkpoint"), ( + "You must specify a path for --pretrained-xlm-checkpoint to use " + "--arch transformer_from_pretrained_xlm" + ) + assert isinstance(task.source_dictionary, cls_dictionary) and isinstance( + task.target_dictionary, cls_dictionary + ), ( + "You should use a MaskedLMDictionary when using --arch " + "transformer_from_pretrained_xlm because the pretrained XLM model " + "was trained using data binarized with MaskedLMDictionary. " + "For translation, you may want to use --task " + "translation_from_pretrained_xlm" + ) + assert not ( + getattr(args, "init_encoder_only", False) + and getattr(args, "init_decoder_only", False) + ), "Only one of --init-encoder-only and --init-decoder-only can be set." + return super().build_model(args, task) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return TransformerEncoderFromPretrainedXLM(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoderFromPretrainedXLM(args, tgt_dict, embed_tokens) + + +def upgrade_state_dict_with_xlm_weights( + state_dict: Dict[str, Any], pretrained_xlm_checkpoint: str +) -> Dict[str, Any]: + """ + Load XLM weights into a Transformer encoder or decoder model. + + Args: + state_dict: state dict for either TransformerEncoder or + TransformerDecoder + pretrained_xlm_checkpoint: checkpoint to load XLM weights from + + Raises: + AssertionError: If architecture (num layers, attention heads, etc.) + does not match between the current Transformer encoder or + decoder and the pretrained_xlm_checkpoint + """ + if not os.path.exists(pretrained_xlm_checkpoint): + raise IOError("Model file not found: {}".format(pretrained_xlm_checkpoint)) + + state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint) + xlm_state_dict = state["model"] + for key in xlm_state_dict.keys(): + + for search_key in ["embed_tokens", "embed_positions", "layers"]: + if search_key in key: + subkey = key[key.find(search_key) :] + assert subkey in state_dict, ( + "{} Transformer encoder / decoder " + "state_dict does not contain {}. Cannot " + "load {} from pretrained XLM checkpoint " + "{} into Transformer.".format( + str(state_dict.keys()), subkey, key, pretrained_xlm_checkpoint + ) + ) + + state_dict[subkey] = xlm_state_dict[key] + return state_dict + + +class TransformerEncoderFromPretrainedXLM(TransformerEncoder): + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + if getattr(args, "init_decoder_only", False): + # Don't load XLM weights for encoder if --init-decoder-only + return + + assert hasattr(args, "pretrained_xlm_checkpoint"), ( + "--pretrained-xlm-checkpoint must be specified to load Transformer " + "encoder from pretrained XLM" + ) + xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights( + state_dict=self.state_dict(), + pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint, + ) + self.load_state_dict(xlm_loaded_state_dict, strict=True) + + +class TransformerDecoderFromPretrainedXLM(TransformerDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__(args, dictionary, embed_tokens, no_encoder_attn) + if getattr(args, "init_encoder_only", False): + # Don't load XLM weights for decoder if --init-encoder-only + return + assert hasattr(args, "pretrained_xlm_checkpoint"), ( + "--pretrained-xlm-checkpoint must be specified to load Transformer " + "decoder from pretrained XLM" + ) + + xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights( + state_dict=self.state_dict(), + pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint, + ) + self.load_state_dict(xlm_loaded_state_dict, strict=True) + + +@register_model_architecture( + "transformer_from_pretrained_xlm", "transformer_from_pretrained_xlm" +) +def base_architecture(args): + transformer_base_architecture(args) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py new file mode 100644 index 0000000..9467b25 --- /dev/null +++ b/fairseq/models/transformer_lm.py @@ -0,0 +1,397 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass, field +from typing import Optional + +from fairseq import options, utils +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder +from omegaconf import II + + +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@dataclass +class TransformerLanguageModelConfig(FairseqDataclass): + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="relu", metadata={"help": "activation function to use"} + ) + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + attention_dropout: float = field( + default=0.0, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN."} + ) + relu_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN."} + ) + decoder_embed_dim: int = field( + default=512, metadata={"help": "decoder embedding dimension"} + ) + decoder_output_dim: int = field( + default=512, metadata={"help": "decoder output dimension"} + ) + decoder_input_dim: int = field( + default=512, metadata={"help": "decoder input dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=2048, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"}) + decoder_attention_heads: int = field( + default=8, metadata={"help": "num decoder attention heads"} + ) + decoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each decoder block"} + ) + no_decoder_final_norm: bool = field( + default=False, + metadata={"help": "don't add an extra layernorm after the last decoder block"}, + ) + adaptive_softmax_cutoff: Optional[str] = field( + default=None, + metadata={ + "help": "comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion" + }, + ) + adaptive_softmax_dropout: float = field( + default=0, + metadata={"help": "sets adaptive softmax dropout for the tail projections"}, + ) + adaptive_softmax_factor: float = field( + default=4, metadata={"help": "adaptive input factor"} + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings (outside self attention)" + }, + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + character_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, uses character embedding convolutions to produce token embeddings" + }, + ) + character_filters: str = field( + default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + metadata={"help": "size of character embeddings"}, + ) + character_embedding_dim: int = field( + default=4, metadata={"help": "size of character embeddings"} + ) + char_embedder_highway_layers: int = field( + default=2, + metadata={"help": "number of highway layers for character token embeddder"}, + ) + adaptive_input: bool = field( + default=False, metadata={"help": "if set, uses adaptive input"} + ) + adaptive_input_factor: float = field( + default=4, metadata={"help": "adaptive input factor"} + ) + adaptive_input_cutoff: Optional[str] = field( + default=None, + metadata={"help": "comma separated list of adaptive input cutoff points."}, + ) + tie_adaptive_weights: bool = field( + default=False, + metadata={ + "help": "if set, ties the weights of adaptive softmax and adaptive input" + }, + ) + tie_adaptive_proj: bool = field( + default=False, + metadata={ + "help": "if set, ties the projection weights of adaptive softmax and adaptive input" + }, + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "LayerDrop probability for decoder"} + ) + decoder_layers_to_keep: Optional[str] = field( + default=None, + metadata={ + "help": "which layers to *keep* when pruning as a comma-separated list" + }, + ) + layernorm_embedding: bool = field( + default=False, metadata={"help": "add layernorm to embedding"} + ) + no_scale_embedding: bool = field( + default=False, metadata={"help": "if True, dont scale embeddings"} + ) + checkpoint_activations: bool = field( + default=False, metadata={"help": "checkpoint activations at each layer"} + ) + quant_noise_pq: float = field( + default=0.0, + metadata={"help": "iterative PQ quantization noise at training time"}, + ) + quant_noise_pq_block_size: int = field( + default=8, + metadata={"help": "block size of quantization noise at training time"}, + ) + # TODO common var add to parent + quant_noise_scalar: float = field( + default=0.0, + metadata={ + "help": "scalar quantization noise and scalar quantization at training time" + }, + ) + add_bos_token: bool = II("task.add_bos_token") + tokens_per_sample: int = II("task.tokens_per_sample") + max_target_positions: Optional[int] = II("task.max_target_positions") + tpu: bool = II("common.tpu") + + +@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) +class TransformerLanguageModel(FairseqLanguageModel): + @classmethod + def hub_models(cls): + def moses_fastbpe(path): + return {"path": path, "tokenizer": "moses", "bpe": "fastbpe"} + + return { + "transformer_lm.gbw.adaptive_huge": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2", + "transformer_lm.wiki103.adaptive": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2", + "transformer_lm.wmt19.en": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2" + ), + "transformer_lm.wmt19.de": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2" + ), + "transformer_lm.wmt19.ru": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2" + ), + } + + def __init__(self, decoder): + super().__init__(decoder) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_lm_architecture(args) + + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + if args.character_embeddings: + embed_tokens = CharacterTokenEmbedder( + task.source_dictionary, + eval(args.character_filters), + args.character_embedding_dim, + args.decoder_embed_dim, + args.char_embedder_highway_layers, + ) + elif args.adaptive_input: + embed_tokens = AdaptiveInput( + len(task.source_dictionary), + task.source_dictionary.pad(), + args.decoder_input_dim, + args.adaptive_input_factor, + args.decoder_embed_dim, + options.eval_str_list(args.adaptive_input_cutoff, type=int), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_input_dim + ) + + if args.tie_adaptive_weights: + assert args.adaptive_input + assert args.adaptive_input_factor == args.adaptive_softmax_factor + assert ( + args.adaptive_softmax_cutoff == args.adaptive_input_cutoff + ), "{} != {}".format( + args.adaptive_softmax_cutoff, args.adaptive_input_cutoff + ) + assert args.decoder_input_dim == args.decoder_output_dim + + decoder = TransformerDecoder( + args, task.target_dictionary, embed_tokens, no_encoder_attn=True + ) + return cls(decoder) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad()) + return embed_tokens + + +@register_model_architecture("transformer_lm", "transformer_lm") +def base_lm_architecture(args): + # backward compatibility for older model checkpoints + if hasattr(args, "no_tie_adaptive_proj"): + # previous models defined --no-tie-adaptive-proj, so use the existence of + # that option to determine if this is an "old" model checkpoint + args.no_decoder_final_norm = True # old models always set this to True + if args.no_tie_adaptive_proj is False: + args.tie_adaptive_proj = True + if hasattr(args, "decoder_final_norm"): + args.no_decoder_final_norm = not args.decoder_final_norm + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.activation_fn = getattr(args, "activation_fn", "relu") + + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + args.add_bos_token = getattr(args, "add_bos_token", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.character_embeddings = getattr(args, "character_embeddings", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # Model training is not stable without this + args.decoder_normalize_before = True + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) + + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) + + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + + +@register_model_architecture("transformer_lm", "transformer_lm_big") +def transformer_lm_big(args): + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_wiki103") +@register_model_architecture("transformer_lm", "transformer_lm_baevski_wiki103") +def transformer_lm_baevski_wiki103(args): + args.decoder_layers = getattr(args, "decoder_layers", 16) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.dropout = getattr(args, "dropout", 0.3) + args.adaptive_input = getattr(args, "adaptive_input", True) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", True) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", "20000,60000") + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "20000,60000" + ) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0.2) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", True) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", True) + transformer_lm_big(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gbw") +@register_model_architecture("transformer_lm", "transformer_lm_baevski_gbw") +def transformer_lm_baevski_gbw(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", True) + transformer_lm_big(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt") +def transformer_lm_gpt(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_small") +def transformer_lm_gpt2_small(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium") +def transformer_lm_gpt2_medium(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 5120) + args.decoder_layers = getattr(args, "decoder_layers", 36) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 20) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_big") +def transformer_lm_gpt2_big(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1600) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6400) + args.decoder_layers = getattr(args, "decoder_layers", 48) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 25) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) diff --git a/fairseq/models/wav2vec/__init__.py b/fairseq/models/wav2vec/__init__.py new file mode 100644 index 0000000..631fd20 --- /dev/null +++ b/fairseq/models/wav2vec/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .wav2vec import * # noqa +from .wav2vec2 import * # noqa +from .wav2vec2_asr import * # noqa +from .unispeech import * diff --git a/fairseq/models/wav2vec/unispeech.py b/fairseq/models/wav2vec/unispeech.py new file mode 100644 index 0000000..d1dd27f --- /dev/null +++ b/fairseq/models/wav2vec/unispeech.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import checkpoint_utils, tasks, utils, pdb +from fairseq.models import ( + BaseFairseqModel, + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer +from fairseq.models.wav2vec.wav2vec2_s1 import Wav2Vec2Model + + +@register_model("unispeech") +class Unispeech(BaseFairseqModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + Wav2Vec2Model.add_args(parser) + parser.add_argument("--replace-prob", type=float, default=0.5) + + def __init__(self, w2v_encoder, args): + super().__init__() + self.w2v_encoder = w2v_encoder + self.args = args + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + base_architecture(args) + w2v_encoder = Wav2VecEncoder(args, task.target_dictionary, task.source_dictionary) + return cls(w2v_encoder, args) + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out"] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def forward(self, **kwargs): + x = self.w2v_encoder(**kwargs) + return x + + # def max_positions(self): + # return None + + +class Wav2VecEncoder(FairseqEncoder): + def __init__(self, args, tgt_dict=None, src_dict=None): + self.apply_mask = args.apply_mask + + super().__init__(src_dict) + + self.w2v_model = Wav2Vec2Model.build_model(args) + + self.final_dropout = nn.Dropout(args.final_dropout) + self.freeze_finetune_updates = args.freeze_finetune_updates + self.num_updates = 0 + self.replace_prob = args.replace_prob + + d = args.encoder_embed_dim + self.encoder_dim = args.encoder_embed_dim + + if tgt_dict is not None: + self.proj = Linear(d, len(tgt_dict)) + elif getattr(args, 'decoder_embed_dim', d) != d: + self.proj = Linear(d, args.decoder_embed_dim) + else: + self.proj = None + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, tbc=True, **kwargs): + w2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + + contrastive_res = self.w2v_model(**w2v_args) + x = contrastive_res["features"] + padding_mask = contrastive_res["padding_mask"] + q = contrastive_res["q"] + + if tbc: + x = x.transpose(0, 1) + q = q.transpose(0, 1) + replace_mat = torch.empty(x.size(0), x.size(1)).fill_(self.replace_prob) + replace_mat = torch.bernoulli(replace_mat).bool().to(x.device) + replace_mat = replace_mat.unsqueeze(-1) + x = x.masked_fill(replace_mat, 0.0) + q.masked_fill(~replace_mat, 0.0) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "contrastive_res": contrastive_res, + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask, + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +@register_model_architecture("unispeech", "unispeech") +def base_architecture(args): + args.dropout_input = getattr(args, "dropout_input", 0) + args.final_dropout = getattr(args, "final_dropout", 0) + args.apply_mask = getattr(args, "apply_mask", True) + args.dropout = getattr(args, "dropout", 0) + args.attention_dropout = getattr(args, "attention_dropout", 0) + args.activation_dropout = getattr(args, "activation_dropout", 0) + + args.mask_length = getattr(args, "mask_length", 10) + args.mask_prob = getattr(args, "mask_prob", 0.5) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) + + args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0) + args.feature_grad_mult = getattr(args, "feature_grad_mult", 0) + args.layerdrop = getattr(args, "layerdrop", 0.0) + args.replace_prob = getattr(args, "replace_prob", 0.5) + + diff --git a/fairseq/models/wav2vec/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py new file mode 100644 index 0000000..772995b --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec.py @@ -0,0 +1,735 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GumbelVectorQuantizer, + KmeansVectorQuantizer, + TransposeLast, +) +from fairseq.utils import buffered_arange + + +logger = logging.getLogger(__name__) + + +@register_model("wav2vec") +class Wav2VecModel(BaseFairseqModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--prediction-steps", + type=int, + metavar="N", + help="number of steps ahead to predict", + ) + parser.add_argument( + "--sample-distance", + type=int, + metavar="N", + help="sample distance from target. does not work properly with cross-sampling", + ) + parser.add_argument( + "--cross-sample-negatives", + type=int, + metavar="N", + help="num of cross sampled negatives", + ) + parser.add_argument( + "--num-negatives", type=int, metavar="N", help="number of negative examples" + ) + parser.add_argument( + "--conv-feature-layers", + type=str, + metavar="EXPR", + help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", + ) + parser.add_argument( + "--conv-aggregator-layers", + type=str, + metavar="EXPR", + help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", + ) + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="dropout to apply within the model", + ) + parser.add_argument( + "--dropout-features", + type=float, + metavar="D", + help="dropout to apply to the features", + ) + parser.add_argument( + "--dropout-agg", + type=float, + metavar="D", + help="dropout to apply after aggregation step", + ) + parser.add_argument( + "--encoder", type=str, choices=["cnn"], help="type of encoder to use" + ) + parser.add_argument( + "--aggregator", + type=str, + choices=["cnn", "gru"], + help="type of aggregator to use", + ) + parser.add_argument( + "--gru-dim", type=int, metavar="N", help="GRU dimensionality" + ) + + parser.add_argument( + "--no-conv-bias", + action="store_true", + help="if set, does not learn bias for conv layers", + ) + parser.add_argument( + "--agg-zero-pad", + action="store_true", + help="if set, zero pads in aggregator instead of repl pad", + ) + + parser.add_argument( + "--skip-connections-feat", + action="store_true", + help="if set, adds skip connections to the feature extractor", + ) + parser.add_argument( + "--skip-connections-agg", + action="store_true", + help="if set, adds skip connections to the aggregator", + ) + parser.add_argument( + "--residual-scale", + type=float, + metavar="D", + help="scales residual by sqrt(value)", + ) + + parser.add_argument( + "--log-compression", + action="store_true", + help="if set, adds a log compression to feature extractor", + ) + + parser.add_argument( + "--balanced-classes", + action="store_true", + help="if set, loss is scaled to balance for number of negatives", + ) + + parser.add_argument( + "--project-features", + choices=["none", "same", "new"], + help="if not none, features are projected using the (same or new) aggregator", + ) + + parser.add_argument( + "--non-affine-group-norm", + action="store_true", + help="if set, group norm is not affine", + ) + + parser.add_argument( + "--offset", + help="if set, introduces an offset from target to predictions. " + 'if set to "auto", it is computed automatically from the receptive field', + ) + + parser.add_argument( + "--activation", + type=str, + choices=["relu", "gelu"], + help="which activation function to use", + ) + + parser.add_argument( + "--vq-type", + type=str, + choices=["none", "gumbel", "kmeans"], + help="which type of quantizer to use", + ) + parser.add_argument( + "--vq-vars", + type=int, + metavar="N", + help="if set, project to this many vector quantized variables per group", + ) + parser.add_argument( + "--vq-groups", + type=int, + metavar="N", + help="number of groups of latent variables", + ) + parser.add_argument( + "--vq-dim", + type=int, + metavar="N", + help="uses this dimensionality for quantized vectors", + ) + parser.add_argument( + "--vq-depth", + type=int, + metavar="N", + help="number of layers for vq weight projection", + ) + parser.add_argument( + "--combine-groups", + action="store_true", + help="if set, variables are shared among groups", + ) + parser.add_argument( + "--vq-temp", + type=str, + metavar="TEMP", + help="temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)", + ) + parser.add_argument( + "--vq-gamma", + type=float, + metavar="D", + help="gamma parameter for kmeans style vector quantization", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_wav2vec_architecture(args) + + model = Wav2VecModel(args) + logger.info(model) + return model + + def __init__(self, args): + super().__init__() + + self.prediction_steps = args.prediction_steps + offset = args.offset + + if args.activation == "relu": + activation = nn.ReLU() + elif args.activation == "gelu": + activation = nn.GELU() + else: + raise Exception("unknown activation " + args.activation) + + if args.encoder == "cnn": + feature_enc_layers = eval(args.conv_feature_layers) + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + log_compression=args.log_compression, + skip_connections=args.skip_connections_feat, + residual_scale=args.residual_scale, + non_affine_group_norm=args.non_affine_group_norm, + activation=activation, + ) + embed = feature_enc_layers[-1][0] + else: + raise Exception("unknown encoder type " + args.encoder) + + self.vector_quantizer = None + if args.vq_type == "gumbel": + self.vector_quantizer = GumbelVectorQuantizer( + dim=embed, + num_vars=args.vq_vars, + temp=eval(args.vq_temp), + groups=args.vq_groups, + combine_groups=args.combine_groups, + vq_dim=args.vq_dim if args.vq_dim > 0 else embed, + time_first=False, + activation=activation, + weight_proj_depth=args.vq_depth, + weight_proj_factor=2, + ) + elif args.vq_type == "kmeans": + self.vector_quantizer = KmeansVectorQuantizer( + dim=embed, + num_vars=args.vq_vars, + groups=args.vq_groups, + combine_groups=args.combine_groups, + vq_dim=args.vq_dim if args.vq_dim > 0 else embed, + time_first=False, + gamma=args.vq_gamma, + ) + else: + assert ( + args.vq_type == "none" or args.vq_type is None + ), "Unknown quantizer type" + + if args.offset == "auto": + assert args.encoder == "cnn" + jin = 0 + rin = 0 + for _, k, stride in feature_enc_layers: + if rin == 0: + rin = k + rin = rin + (k - 1) * jin + if jin == 0: + jin = stride + else: + jin *= stride + offset = math.ceil(rin / jin) + + offset = int(offset) + + def make_aggregator(): + if args.aggregator == "cnn": + agg_layers = eval(args.conv_aggregator_layers) + agg_dim = agg_layers[-1][0] + feature_aggregator = ConvAggegator( + conv_layers=agg_layers, + embed=embed, + dropout=args.dropout, + skip_connections=args.skip_connections_agg, + residual_scale=args.residual_scale, + non_affine_group_norm=args.non_affine_group_norm, + conv_bias=not args.no_conv_bias, + zero_pad=args.agg_zero_pad, + activation=activation, + ) + elif args.aggregator == "gru": + agg_dim = args.gru_dim + feature_aggregator = nn.Sequential( + TransposeLast(), + nn.GRU( + input_size=embed, + hidden_size=agg_dim, + num_layers=1, + dropout=args.dropout, + ), + TransposeLast(deconstruct_idx=0), + ) + else: + raise Exception("unknown aggregator type " + args.aggregator) + + return feature_aggregator, agg_dim + + self.feature_aggregator, agg_dim = make_aggregator() + + self.wav2vec_predictions = Wav2VecPredictionsModel( + in_dim=agg_dim, + out_dim=embed, + prediction_steps=args.prediction_steps, + n_negatives=args.num_negatives, + cross_sample_negatives=args.cross_sample_negatives, + sample_distance=args.sample_distance, + dropout=args.dropout, + offset=offset, + balanced_classes=args.balanced_classes, + infonce=args.infonce, + ) + + self.dropout_feats = nn.Dropout(p=args.dropout_features) + self.dropout_agg = nn.Dropout(p=args.dropout_agg) + + if args.project_features == "none": + self.project_features = None + elif args.project_features == "same": + self.project_features = self.feature_aggregator + elif args.project_features == "new": + self.project_features, _ = make_aggregator() + + def forward(self, source): + result = {} + + features = self.feature_extractor(source) + if self.vector_quantizer: + q_res = self.vector_quantizer(features) + features = q_res["x"] + for k in q_res.keys(): + if k != "x": + result[k] = q_res[k] + + x = self.dropout_feats(features) + x = self.feature_aggregator(x) + x = self.dropout_agg(x) + + if self.project_features is not None: + features = self.project_features(features) + x, targets = self.wav2vec_predictions(x, features) + result["cpc_logits"] = x + result["cpc_targets"] = targets + + return result + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + + def max_positions(self): + """Maximum length supported by the model.""" + return sys.maxsize + + def get_logits(self, net_output): + logits = net_output["cpc_logits"] + return logits + + def get_targets(self, sample, net_output): + t = net_output["cpc_targets"] + if isinstance(t, tuple): + t = t[0] + return t.contiguous() + + def get_target_weights(self, targets, net_output): + targets = net_output["cpc_targets"] + if isinstance(targets, tuple) and targets[-1] is not None: + return targets[-1] + return None + + def get_extra_losses(self, net_output): + loss = None + if "prob_perplexity" in net_output: + loss = net_output["num_vars"] - net_output["prob_perplexity"] + elif "kmeans_loss" in net_output: + loss = net_output["kmeans_loss"] + + return loss + + +def norm_block(is_layer_norm, dim, affine=True): + if is_layer_norm: + mod = nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=affine), + TransposeLast(), + ) + else: + mod = Fp32GroupNorm(1, dim, affine=affine) + + return mod + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers, + dropout, + log_compression, + skip_connections, + residual_scale, + non_affine_group_norm, + activation, + ): + super().__init__() + + def block(n_in, n_out, k, stride): + return nn.Sequential( + nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), + nn.Dropout(p=dropout), + norm_block( + is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm + ), + activation, + ) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for dim, k, stride in conv_layers: + self.conv_layers.append(block(in_d, dim, k, stride)) + in_d = dim + + self.log_compression = log_compression + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + residual = x + x = conv(x) + if self.skip_connections and x.size(1) == residual.size(1): + tsz = x.size(2) + r_tsz = residual.size(2) + residual = residual[..., :: r_tsz // tsz][..., :tsz] + x = (x + residual) * self.residual_scale + + if self.log_compression: + x = x.abs() + x = x + 1 + x = x.log() + + return x + + +class ZeroPad1d(nn.Module): + def __init__(self, pad_left, pad_right): + super().__init__() + self.pad_left = pad_left + self.pad_right = pad_right + + def forward(self, x): + return F.pad(x, (self.pad_left, self.pad_right)) + + +class ConvAggegator(nn.Module): + def __init__( + self, + conv_layers, + embed, + dropout, + skip_connections, + residual_scale, + non_affine_group_norm, + conv_bias, + zero_pad, + activation, + ): + super().__init__() + + def block(n_in, n_out, k, stride): + # padding dims only really make sense for stride = 1 + ka = k // 2 + kb = ka - 1 if k % 2 == 0 else ka + + pad = ( + ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0)) + ) + + return nn.Sequential( + pad, + nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias), + nn.Dropout(p=dropout), + norm_block(False, n_out, affine=not non_affine_group_norm), + activation, + ) + + in_d = embed + self.conv_layers = nn.ModuleList() + self.residual_proj = nn.ModuleList() + for dim, k, stride in conv_layers: + if in_d != dim and skip_connections: + self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False)) + else: + self.residual_proj.append(None) + + self.conv_layers.append(block(in_d, dim, k, stride)) + in_d = dim + self.conv_layers = nn.Sequential(*self.conv_layers) + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + for rproj, conv in zip(self.residual_proj, self.conv_layers): + residual = x + x = conv(x) + if self.skip_connections: + if rproj is not None: + residual = rproj(residual) + x = (x + residual) * self.residual_scale + return x + + +class Wav2VecPredictionsModel(nn.Module): + def __init__( + self, + in_dim, + out_dim, + prediction_steps, + n_negatives, + cross_sample_negatives, + sample_distance, + dropout, + offset, + balanced_classes, + infonce, + ): + super().__init__() + + self.n_negatives = n_negatives + self.cross_sample_negatives = cross_sample_negatives + self.sample_distance = sample_distance + self.project_to_steps = nn.ConvTranspose2d( + in_dim, out_dim, (1, prediction_steps) + ) + self.dropout = nn.Dropout(p=dropout) + self.offset = offset + self.balanced_classes = balanced_classes + self.infonce = infonce + + def sample_negatives(self, y): + bsz, fsz, tsz = y.shape + + y = y.transpose(0, 1) # BCT -> CBT + y = y.contiguous().view(fsz, -1) # CBT => C(BxT) + + cross_high = tsz * bsz + high = tsz if self.sample_distance is None else min(tsz, self.sample_distance) + assert high > 1 + + neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz)) + + with torch.no_grad(): + if self.n_negatives > 0: + tszs = ( + buffered_arange(tsz) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + neg_idxs = torch.randint( + low=0, high=high - 1, size=(bsz, self.n_negatives * tsz) + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(tsz) + .unsqueeze(-1) + .expand(-1, self.cross_sample_negatives) + .flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * tsz), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[..., neg_idxs.view(-1)] + negs = negs.view( + fsz, bsz, self.n_negatives + self.cross_sample_negatives, tsz + ).permute( + 2, 1, 0, 3 + ) # to NxBxCxT + + return negs + + def forward(self, x, y): + + x = x.unsqueeze(-1) + x = self.project_to_steps(x) # BxCxTxS + x = self.dropout(x) + + negatives = self.sample_negatives(y) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) # Copies x B x C x T + + copies = targets.size(0) + bsz, dim, tsz, steps = x.shape + steps = min(steps, tsz - self.offset) + + predictions = x.new( + bsz * copies * (tsz - self.offset + 1) * steps + - ((steps + 1) * steps // 2) * copies * bsz + ) + if self.infonce: + labels = predictions.new_full( + (predictions.shape[0] // copies,), 0, dtype=torch.long + ) + else: + labels = torch.zeros_like(predictions) + weights = ( + torch.full_like(labels, 1 / self.n_negatives) + if self.balanced_classes and not self.infonce + else None + ) + + start = end = 0 + for i in range(steps): + offset = i + self.offset + end = start + (tsz - offset) * bsz * copies + if self.infonce: + predictions[start:end] = torch.einsum( + "bct,nbct->tbn", x[..., :-offset, i], targets[..., offset:] + ).flatten() + else: + pos_num = (end - start) // copies + predictions[start:end] = torch.einsum( + "bct,nbct->nbt", x[..., :-offset, i], targets[..., offset:] + ).flatten() + labels[start : start + pos_num] = 1.0 + if weights is not None: + weights[start : start + pos_num] = 1.0 + start = end + assert end == predictions.numel(), "{} != {}".format(end, predictions.numel()) + + if self.infonce: + predictions = predictions.view(-1, copies) + else: + if weights is not None: + labels = (labels, weights) + + return predictions, labels + + +@register_model_architecture("wav2vec", "wav2vec") +def base_wav2vec_architecture(args): + conv_feature_layers = "[(512, 10, 5)]" + conv_feature_layers += " + [(512, 8, 4)]" + conv_feature_layers += " + [(512, 4, 2)] * 3" + args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers) + + args.conv_aggregator_layers = getattr( + args, "conv_aggregator_layers", "[(512, 3, 1)] * 9" + ) + + args.prediction_steps = getattr(args, "prediction_steps", 12) + args.num_negatives = getattr(args, "num_negatives", 1) + args.sample_distance = getattr(args, "sample_distance", None) + args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0) + + args.dropout = getattr(args, "dropout", 0.0) + args.dropout_features = getattr(args, "dropout_features", 0.0) + args.dropout_agg = getattr(args, "dropout_agg", 0.0) + args.encoder = getattr(args, "encoder", "cnn") + args.aggregator = getattr(args, "aggregator", "cnn") + + args.skip_connections_feat = getattr(args, "skip_connections_feat", False) + args.skip_connections_agg = getattr(args, "skip_connections_agg", False) + args.residual_scale = getattr(args, "residual_scale", 0.5) + + args.gru_dim = getattr(args, "gru_dim", 512) + + args.no_conv_bias = getattr(args, "no_conv_bias", False) + args.agg_zero_pad = getattr(args, "agg_zero_pad", False) + + args.log_compression = getattr(args, "log_compression", False) + + args.balanced_classes = getattr(args, "balanced_classes", False) + args.infonce = getattr(args, "infonce", False) + args.project_features = getattr(args, "project_features", "none") + + args.non_affine_group_norm = getattr(args, "non_affine_group_norm", False) + + args.offset = getattr(args, "offset", "auto") + + args.activation = getattr(args, "activation", "relu") + + args.vq_type = getattr(args, "vq_type", "none") + args.vq_vars = getattr(args, "vq_vars", 320) + args.vq_groups = getattr(args, "vq_groups", 2) + args.vq_dim = getattr(args, "vq_dim", 0) + args.vq_depth = getattr(args, "vq_depth", 1) + args.combine_groups = getattr(args, "combine_groups", False) + args.vq_temp = getattr(args, "vq_temp", "(2.0, 0.5, 0.999995)") + args.vq_gamma = getattr(args, "vq_gamma", 0.25) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py new file mode 100644 index 0000000..a33a2b4 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -0,0 +1,1047 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils, checkpoint_utils, pdb +from fairseq.data.data_utils import compute_mask_indices +from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + GumbelVectorQuantizer, + LayerNorm, + MultiheadAttention, + SamePad, + TransposeLast, +) +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import buffered_arange + + +@register_model("wav2vec2") +class Wav2Vec2Model(BaseFairseqModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + + parser.add_argument("--pretrained-path", help="path to pretrained model", default=None) + parser.add_argument( + "--extractor-mode", + choices=["default", "layer_norm"], + help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)", + ) + + parser.add_argument( + "--encoder-layers", + type=int, + metavar="L", + help="num encoder layers in the transformer", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="H", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="F", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="A", + help="num encoder attention heads", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="dropout probability for the transformer", + ) + + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + + parser.add_argument( + "--activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN", + ) + + parser.add_argument( + "--final-dim", + type=int, + metavar="D", + help="project final representations and targets to this many dimensions", + ) + + parser.add_argument( + "--layer-norm-first", + action="store_true", + help="apply layernorm first in the transformer", + ) + + parser.add_argument( + "--encoder-layerdrop", + type=float, + help="probability of dropping a tarnsformer layer", + ) + + parser.add_argument( + "--conv-feature-layers", + type=str, + metavar="EXPR", + help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--logit-temp", type=float, help="temperature to divide logits by" + ) + + parser.add_argument( + "--quantize-targets", action="store_true", help="use quantized targets" + ) + + parser.add_argument( + "--quantize-input", action="store_true", help="use quantized inputs" + ) + + parser.add_argument( + "--same-quantizer", + action="store_true", + help="use same quantizer for inputs and targets", + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + help="multiply feature extractor var grads by this", + ) + + parser.add_argument( + "--latent-vars", + type=int, + metavar="N", + help="number of latent variables V in each group of the codebook", + ) + + parser.add_argument( + "--latent-groups", + type=int, + metavar="N", + help="number of groups G of latent variables in the codebook", + ) + + parser.add_argument( + "--latent-dim", + type=int, + metavar="N", + help="if set, uses this dimensionality for latent variables. otherwise uses final_dim / latent_groups", + ) + + parser.add_argument("--mask-length", type=int, help="mask length") + + parser.add_argument( + "--mask-prob", type=float, help="probability of replacing a token with mask" + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-other", + type=float, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", + ) + + parser.add_argument( + "--no-mask-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--mask-channel-length", + type=int, + help="repeat the mask indices multiple times", + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-channel-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + metavar="D", + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--num-negatives", type=int, metavar="N", help="number of negative examples" + ) + + parser.add_argument( + "--negatives-from-everywhere", + action="store_true", + help="sample negatives from everywhere, not just masked states", + ) + + parser.add_argument( + "--cross-sample-negatives", + type=int, + metavar="N", + help="num of cross sampled negatives", + ) + + parser.add_argument( + "--codebook-negatives", + type=int, + metavar="N", + help="num of codebook sampled negatives", + ) + + parser.add_argument( + "--conv-pos", + type=int, + metavar="N", + help="number of filters for convolutional positional embeddings", + ) + + parser.add_argument( + "--conv-pos-groups", + type=int, + metavar="N", + help="number of groups for convolutional positional embedding", + ) + + parser.add_argument( + "--latent-temp", + type=str, + metavar="D", + help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)", + ) + + parser.add_argument( + "--target-glu", action="store_true", help="adds projection + glu to targets" + ) + + parser.add_argument( + "--conv-bias", action="store_true", help="include bias in conv encoder" + ) + + def __init__(self, args): + super().__init__() + self.args = args + + feature_enc_layers = eval(args.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=args.extractor_mode, + conv_bias=args.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, args.encoder_embed_dim) + if self.embed != args.encoder_embed_dim and not args.quantize_input + else None + ) + + self.mask_prob = args.mask_prob + self.mask_selection = args.mask_selection + self.mask_other = args.mask_other + self.mask_length = args.mask_length + self.no_mask_overlap = args.no_mask_overlap + self.mask_min_space = args.mask_min_space + + self.mask_channel_prob = args.mask_channel_prob + self.mask_channel_selection = args.mask_channel_selection + self.mask_channel_other = args.mask_channel_other + self.mask_channel_length = args.mask_channel_length + self.no_mask_channel_overlap = args.no_mask_channel_overlap + self.mask_channel_min_space = args.mask_channel_min_space + + self.dropout_input = nn.Dropout(args.dropout_input) + self.dropout_features = nn.Dropout(args.dropout_features) + + self.feature_grad_mult = args.feature_grad_mult + + self.quantizer = None + self.input_quantizer = None + + self.n_negatives = args.num_negatives + self.cross_sample_negatives = args.cross_sample_negatives + self.codebook_negatives = args.codebook_negatives + self.negatives_from_everywhere = args.negatives_from_everywhere + + self.logit_temp = args.logit_temp + + final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim + + if args.quantize_targets: + vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=args.latent_vars, + temp=eval(args.latent_temp), + groups=args.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_q = nn.Linear(vq_dim, final_dim) + else: + self.project_q = nn.Linear(self.embed, final_dim) + + if args.quantize_input: + if args.same_quantizer and self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = ( + args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim + ) + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=args.latent_vars, + temp=eval(args.latent_temp), + groups=args.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim) + + self.mask_emb = nn.Parameter( + torch.FloatTensor(args.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(args) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if args.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim) + if getattr(args, 'pretrained_path', None) is not None: + state = checkpoint_utils.load_checkpoint_to_cpu( + args.pretrained_path)['model'] + from collections import OrderedDict + new_state = OrderedDict() + model_state = self.state_dict() + for k, v in state.items(): + if 'w2v_encoder.w2v_model.' in k: + k = k.replace('w2v_encoder.w2v_model.', '') + if k in model_state and model_state[k].size() == v.size(): + new_state[k] = v + logging.info('load parameter {} in Wav2vec2 from path'.format(k, args.pretrained_path)) + self.load_state_dict(new_state, strict=False) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + @classmethod + def build_model(cls, args, task=None): + """Build a new model instance.""" + + # make sure all arguments are present + base_architecture(args) + + return cls(args) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + neg_idxs = torch.randint( + low=0, high=high - 1, size=(bsz, self.n_negatives * num) + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.cross_sample_negatives) + .flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view( + bsz, num, self.n_negatives + self.cross_sample_negatives, fsz + ).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def compute_preds(self, x, y, negatives): + + neg_is_pos = (y == negatives).all(-1) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + + logits /= self.logit_temp + + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + + return logits + + def forward(self, source, padding_mask=None, mask=True, features_only=False): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + num_vars = None + code_ppl = None + prob_ppl = None + curr_temp = None + + if self.input_quantizer: + q = self.input_quantizer(features, produce_targets=False) + features = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + features = self.project_inp(features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask) + if mask_indices is not None: + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + y = unmasked_features + else: + x = features + y = unmasked_features + mask_indices = None + + x = self.encoder(x, padding_mask=padding_mask) + + if features_only: + return {"x": x, "padding_mask": padding_mask} + + results = {"features": x, "feautre_padding_mask": padding_mask} + + if self.quantizer: + q = self.quantizer(y, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + + y = self.project_q(y) + + if self.negatives_from_everywhere: + neg_cands = self.quantizer(unmasked_features, produce_targets=False) + negs = self.project_q(neg_cands['x']) + negs, _ = self.sample_negatives(negs, y.size(1)) + + else: + negs, _ = self.sample_negatives(y, y.size(1)) + + if self.codebook_negatives > 0: + cb_negs = self.quantizer.sample_from_codebook( + y.size(0) * y.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, y.size(0), y.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + negs = torch.cat([negs, cb_negs], dim=0) + else: + y = self.project_q(y) + + if self.negatives_from_everywhere: + negs, _ = self.sample_negatives(unmasked_features, y.size(1)) + negs = self.project_q(negs) + else: + negs, _ = self.sample_negatives(y, y.size(1)) + + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + + x = self.final_proj(x) + x = self.compute_preds(x, y, negs) + + + results['x'] = x + results['padding_mask'] = padding_mask + results['features_pen'] = features_pen + if prob_ppl is not None: + results["prob_perplexity"] = prob_ppl + results["code_perplexity"] = code_ppl + results["num_vars"] = num_vars + results["temp"] = curr_temp + + return results + + def quantize(self, x): + assert self.quantizer is not None + x = self.feature_extractor(x) + x = x.transpose(1, 2) + x = self.layer_norm(x) + return self.quantizer.forward_idx(x) + + def extract_features(self, source, padding_mask, mask=False): + res = self.forward(source, padding_mask, mask=mask, features_only=True) + return res["x"], res["padding_mask"] + + def get_logits(self, net_output): + logits = net_output["x"] + logits = logits.transpose(0, 2) + logits = logits.reshape(-1, logits.size(-1)) + return logits + + def get_targets(self, sample, net_output, expand_steps=True): + x = net_output["x"] + return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long) + + def get_extra_losses(self, net_output): + pen = [] + + if "prob_perplexity" in net_output: + pen.append( + (net_output["num_vars"] - net_output["prob_perplexity"]) + / net_output["num_vars"] + ) + + if "features_pen" in net_output: + pen.append(net_output["features_pen"]) + + return pen + + def remove_pretraining_modules(self): + self.quantizer = None + self.project_q = None + self.target_glu = None + self.final_proj = None + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + for _ in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None): + x = self.extract_features(x, padding_mask) + + if self.layer_norm_first: + x = self.layer_norm(x) + + return x + + def extract_features(self, x, padding_mask=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) + layer_results.append(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = utils.get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn + + +@register_model_architecture("wav2vec2", "wav2vec2") +def base_architecture(args): + args.extractor_mode = getattr(args, "extractor_mode", "default") + + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + + args.final_dim = getattr(args, "final_dim", 0) + + args.layer_norm_first = getattr(args, "layer_norm_first", False) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + + conv_feature_layers = "[(512, 10, 5)]" + conv_feature_layers += " + [(512, 8, 4)]" + conv_feature_layers += " + [(512, 4, 2)] * 3" + conv_feature_layers += " + [(512, 1, 1)]" + args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers) + + args.logit_temp = getattr(args, "logit_temp", 0.1) + + args.quantize_targets = getattr(args, "quantize_targets", False) + args.quantize_input = getattr(args, "quantize_input", False) + args.same_quantizer = getattr(args, "same_quantizer", False) + + args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) + + args.latent_vars = getattr(args, "latent_vars", 320) + args.latent_groups = getattr(args, "latent_groups", 2) + args.latent_dim = getattr(args, "latent_dim", 0) + + args.mask_length = getattr(args, "mask_length", 10) + args.mask_prob = getattr(args, "mask_prob", 0.65) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_min_space = getattr(args, "mask_min_space", 1) + + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) + args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1) + + args.dropout_input = getattr(args, "dropout_input", 0) + args.dropout_features = getattr(args, "dropout_features", 0) + + args.num_negatives = getattr(args, "num_negatives", 100) + args.negatives_from_everywhere = getattr(args, "negatives_from_everywhere", False) + args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0) + args.codebook_negatives = getattr(args, "codebook_negatives", 0) + + args.conv_pos = getattr(args, "conv_pos", 128) + args.conv_pos_groups = getattr(args, "conv_pos_groups", 16) + + args.latent_temp = getattr(args, "latent_temp", "(2,0.5,0.999995)") + + args.target_glu = getattr(args, "target_glu", False) + + args.conv_bias = getattr(args, "conv_bias", False) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py new file mode 100644 index 0000000..7301c34 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -0,0 +1,698 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import contextlib +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import checkpoint_utils, tasks, utils, pdb +from fairseq.models import ( + BaseFairseqModel, + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer + + +def add_common_args(parser): + parser.add_argument("--w2v-path", help="path to wav2vec 2.0 model") + parser.add_argument( + "--no-pretrained-weights", + action="store_true", + help="if true, does not load pretrained weights", + ) + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + parser.add_argument( + "--final-dropout", + type=float, + metavar="D", + help="dropout after transformer and before final projection", + ) + parser.add_argument( + "--apply-mask", action="store_true", help="apply masking during fine-tuning" + ) + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="dropout probability inside wav2vec 2.0 model", + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights inside wav2vec 2.0 model", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN inside wav2vec 2.0 model", + ) + + parser.add_argument( + "--mask-length", type=int, help="repeat the mask indices multiple times" + ) + + parser.add_argument( + "--mask-prob", type=float, help="probability of replacing a token with mask" + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--no-mask-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-channel-length", type=int, help="repeat the mask indices multiple times" + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--freeze-finetune-updates", + default=0, + type=int, + help="dont finetune wav2vec for this many updates", + ) + + parser.add_argument( + "--feature-grad-mult", + default=None, + type=float, + help="reset feature grad mult in wav2vec 2.0 to this", + ) + + parser.add_argument( + "--layerdrop", + default=0.0, + type=float, + help="probability of dropping a layer in wav2vec 2.0", + ) + + + +@register_model("wav2vec_ctc") +class Wav2VecCtc(BaseFairseqModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + add_common_args(parser) + + def __init__(self, w2v_encoder, args): + super().__init__() + self.w2v_encoder = w2v_encoder + self.args = args + if getattr(args, 'w2v_path', None) is not None and not args.no_pretrained_weights: + state = checkpoint_utils.load_checkpoint_to_cpu( + args.w2v_path)['model'] + from collections import OrderedDict + new_state = OrderedDict() + model_state = self.state_dict() + for k, v in state.items(): + if k in model_state and model_state[k].size() == v.size(): + new_state[k] = v + logging.info('load parameter {} from pretrained model'.format(k)) + self.load_state_dict(new_state, strict=False) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + base_architecture(args) + w2v_encoder = Wav2VecEncoder(args, task.target_dictionary) + return cls(w2v_encoder, args) + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out"] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def forward(self, **kwargs): + x = self.w2v_encoder(**kwargs) + return x + + # def max_positions(self): + # return None + + +@register_model("wav2vec_seq2seq") +class TransformerModel(FairseqEncoderDecoderModel): + def __init__(self, args, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + add_common_args(parser) + + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-layerdrop", + type=float, + metavar="D", + help="decoder layerdrop chance", + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the decoder", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--no-token-positional-embeddings", + default=False, + action="store_true", + help="if set, disables positional embeddings (outside self attention)", + ) + + parser.add_argument( + "--decoder-dropout", + type=float, + metavar="D", + help="dropout probability in the decoder", + ) + parser.add_argument( + "--decoder-attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights inside the decoder", + ) + parser.add_argument( + "--decoder-activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN inside the decoder", + ) + + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not hasattr(args, "max_source_positions"): + args.max_source_positions = 2048 + if not hasattr(args, "max_target_positions"): + args.max_target_positions = 2048 + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + return emb + + decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim) + + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + return TransformerModel(args, encoder, decoder) + + @classmethod + def build_encoder(cls, args): + return Wav2VecEncoder(args) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoder(args, tgt_dict, embed_tokens) + + def forward(self, **kwargs): + encoder_out = self.encoder(tbc=False, **kwargs) + decoder_out = self.decoder(encoder_out=encoder_out, **kwargs) + return decoder_out + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + +class Wav2VecEncoder(FairseqEncoder): + def __init__(self, args, tgt_dict=None): + self.apply_mask = args.apply_mask + + if getattr(args, "w2v_args", None) is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + args.w2v_path + ) + w2v_args = state.get("args", None) or state["cfg"].model + else: + state = None + w2v_args = args.w2v_args + + w2v_args.dropout = args.dropout + w2v_args.activation_dropout = args.activation_dropout + w2v_args.dropout_input = args.dropout_input + w2v_args.attention_dropout = args.attention_dropout + w2v_args.mask_length = args.mask_length + w2v_args.mask_prob = args.mask_prob + w2v_args.mask_selection = args.mask_selection + w2v_args.mask_other = args.mask_other + w2v_args.no_mask_overlap = args.no_mask_overlap + w2v_args.mask_channel_length = args.mask_channel_length + w2v_args.mask_channel_prob = args.mask_channel_prob + w2v_args.mask_channel_selection = args.mask_channel_selection + w2v_args.mask_channel_other = args.mask_channel_other + w2v_args.no_mask_channel_overlap = args.no_mask_channel_overlap + w2v_args.layerdrop = args.layerdrop + w2v_args.feature_grad_mult = args.feature_grad_mult + + + assert ( + args.normalize == w2v_args.normalize + ), "Fine-tuning works best when data normalization is the same" + + w2v_args.data = args.data + task = tasks.setup_task(w2v_args) + model = task.build_model(w2v_args) + if hasattr(model, 'w2v_encoder'): + model.w2v_encoder.w2v_model.remove_pretraining_modules() + + + if state is not None and not args.no_pretrained_weights: + model.load_state_dict(state["model"], strict=True) + + #model.remove_pretraining_modules() + #model = model.w2v_encoder.w2v_model + + + super().__init__(task.source_dictionary) + + if hasattr(model, 'w2v_encoder'): + self.encoder_dim = model.w2v_encoder.encoder_dim + self.w2v_model = model.w2v_encoder.w2v_model + else: + self.encoder_dim = w2v_args.encoder_embed_dim + self.w2v_model = model + #self.w2v_model.remove_pretraining_modules() + + + self.final_dropout = nn.Dropout(args.final_dropout) + self.freeze_finetune_updates = args.freeze_finetune_updates + self.num_updates = 0 + + if tgt_dict is not None: + self.proj = Linear(self.encoder_dim, len(tgt_dict)) + elif getattr(args, "decoder_embed_dim", self.encoder_dim) != self.encoder_dim: + self.proj = Linear(self.encoder_dim, args.decoder_embed_dim) + else: + self.proj = None + + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, tbc=True, **kwargs): + + w2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + x, padding_mask = self.w2v_model.extract_features(**w2v_args) + + if tbc: + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask, + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__(dictionary) + + self.dropout = args.decoder_dropout + self.share_input_output_embed = args.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + self.output_embed_dim = args.decoder_embed_dim + args.encoder_embed_dim = embed_dim + + self.layerdrop = args.decoder_layerdrop + + padding_idx = embed_tokens.padding_idx + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + args.max_target_positions, + embed_dim, + padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + args = copy.deepcopy(args) + args.dropout = args.decoder_dropout + args.attention_dropout = args.decoder_attention_dropout + args.activation_dropout = args.decoder_activation_dropout + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) + + if not self.share_input_output_embed: + self.embed_out = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5) + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + prev_output_tokens = prev_output_tokens.long() + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + x = self.output_layer(x) + return x, extra + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states = [x] + + # decoder layers + for layer in self.layers: + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, attn, _ = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["encoder_padding_mask"] + if encoder_out is not None + else None, + incremental_state, + self_attn_mask=self.buffered_future_mask(x) + if incremental_state is None + else None, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, {"attn": attn, "inner_states": inner_states} + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +@register_model_architecture("wav2vec_ctc", "wav2vec_ctc") +def base_architecture(args): + args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False) + args.dropout_input = getattr(args, "dropout_input", 0) + args.final_dropout = getattr(args, "final_dropout", 0) + args.apply_mask = getattr(args, "apply_mask", False) + args.dropout = getattr(args, "dropout", 0) + args.attention_dropout = getattr(args, "attention_dropout", 0) + args.activation_dropout = getattr(args, "activation_dropout", 0) + + args.mask_length = getattr(args, "mask_length", 10) + args.mask_prob = getattr(args, "mask_prob", 0.5) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) + + args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0) + args.feature_grad_mult = getattr(args, "feature_grad_mult", 0) + args.layerdrop = getattr(args, "layerdrop", 0.0) + + +@register_model_architecture("wav2vec_seq2seq", "wav2vec_seq2seq") +def seq2seq_architecture(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_layers = getattr(args, "decoder_layers", 10) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.decoder_dropout = getattr(args, "decoder_dropout", 0) + args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0) + args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + + base_architecture(args) diff --git a/fairseq/models/wav2vec/wav2vec2_s1.py b/fairseq/models/wav2vec/wav2vec2_s1.py new file mode 100644 index 0000000..7e42ed1 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2_s1.py @@ -0,0 +1,1067 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils, checkpoint_utils, pdb +from fairseq.data.data_utils import compute_mask_indices +from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + GumbelVectorQuantizer, + LayerNorm, + MultiheadAttention, + SamePad, + TransposeLast, +) +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import buffered_arange + + +@register_model("wav2vec2_s1") +class Wav2Vec2Model(BaseFairseqModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + + parser.add_argument("--pretrained-path", help="path to pretrained model", default=None) + parser.add_argument( + "--extractor-mode", + choices=["default", "layer_norm"], + help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)", + ) + + parser.add_argument( + "--encoder-layers", + type=int, + metavar="L", + help="num encoder layers in the transformer", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="H", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="F", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="A", + help="num encoder attention heads", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="dropout probability for the transformer", + ) + + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + + parser.add_argument( + "--activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN", + ) + + parser.add_argument( + "--final-dim", + type=int, + metavar="D", + help="project final representations and targets to this many dimensions", + ) + + parser.add_argument( + "--layer-norm-first", + action="store_true", + help="apply layernorm first in the transformer", + ) + + parser.add_argument( + "--encoder-layerdrop", + type=float, + help="probability of dropping a tarnsformer layer", + ) + + parser.add_argument( + "--conv-feature-layers", + type=str, + metavar="EXPR", + help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--logit-temp", type=float, help="temperature to divide logits by" + ) + + parser.add_argument( + "--quantize-targets", action="store_true", help="use quantized targets" + ) + + parser.add_argument( + "--quantize-input", action="store_true", help="use quantized inputs" + ) + + parser.add_argument( + "--same-quantizer", + action="store_true", + help="use same quantizer for inputs and targets", + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + help="multiply feature extractor var grads by this", + ) + + parser.add_argument( + "--latent-vars", + type=int, + metavar="N", + help="number of latent variables V in each group of the codebook", + ) + + parser.add_argument( + "--latent-groups", + type=int, + metavar="N", + help="number of groups G of latent variables in the codebook", + ) + + parser.add_argument( + "--latent-dim", + type=int, + metavar="N", + help="if set, uses this dimensionality for latent variables. otherwise uses final_dim / latent_groups", + ) + + parser.add_argument("--mask-length", type=int, help="mask length") + + parser.add_argument( + "--mask-prob", type=float, help="probability of replacing a token with mask" + ) + + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-other", + type=float, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", + ) + + parser.add_argument( + "--no-mask-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--mask-channel-length", + type=int, + help="repeat the mask indices multiple times", + ) + + parser.add_argument( + "--mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--mask-channel-other", + type=float, + help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", + ) + + parser.add_argument( + "--no-mask-channel-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--mask-channel-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + + parser.add_argument( + "--dropout-features", + type=float, + metavar="D", + help="dropout to apply to the features (after feat extr)", + ) + + parser.add_argument( + "--num-negatives", type=int, metavar="N", help="number of negative examples" + ) + + parser.add_argument( + "--negatives-from-everywhere", + action="store_true", + help="sample negatives from everywhere, not just masked states", + ) + + parser.add_argument( + "--cross-sample-negatives", + type=int, + metavar="N", + help="num of cross sampled negatives", + ) + + parser.add_argument( + "--codebook-negatives", + type=int, + metavar="N", + help="num of codebook sampled negatives", + ) + + parser.add_argument( + "--conv-pos", + type=int, + metavar="N", + help="number of filters for convolutional positional embeddings", + ) + + parser.add_argument( + "--conv-pos-groups", + type=int, + metavar="N", + help="number of groups for convolutional positional embedding", + ) + + parser.add_argument( + "--latent-temp", + type=str, + metavar="D", + help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)", + ) + + parser.add_argument( + "--target-glu", action="store_true", help="adds projection + glu to targets" + ) + + parser.add_argument( + "--conv-bias", action="store_true", help="include bias in conv encoder" + ) + parser.add_argument( + "--transpose", action="store_true" + ) + + def __init__(self, args): + super().__init__() + self.args = args + + feature_enc_layers = eval(args.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=args.extractor_mode, + conv_bias=args.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, args.encoder_embed_dim) + if self.embed != args.encoder_embed_dim and not args.quantize_input + else None + ) + + self.mask_prob = args.mask_prob + self.mask_selection = args.mask_selection + self.mask_other = args.mask_other + self.mask_length = args.mask_length + self.no_mask_overlap = args.no_mask_overlap + self.mask_min_space = args.mask_min_space + + self.mask_channel_prob = args.mask_channel_prob + self.mask_channel_selection = args.mask_channel_selection + self.mask_channel_other = args.mask_channel_other + self.mask_channel_length = args.mask_channel_length + self.no_mask_channel_overlap = args.no_mask_channel_overlap + self.mask_channel_min_space = args.mask_channel_min_space + + self.dropout_input = nn.Dropout(args.dropout_input) + self.dropout_features = nn.Dropout(args.dropout_features) + + self.feature_grad_mult = args.feature_grad_mult + + self.quantizer = None + self.input_quantizer = None + + self.n_negatives = args.num_negatives + self.cross_sample_negatives = args.cross_sample_negatives + self.codebook_negatives = args.codebook_negatives + self.negatives_from_everywhere = args.negatives_from_everywhere + + self.logit_temp = args.logit_temp + self.transpose = args.transpose + + final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim + + if args.quantize_targets: + vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=args.latent_vars, + temp=eval(args.latent_temp), + groups=args.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_q = nn.Linear(vq_dim, final_dim) + else: + self.project_q = nn.Linear(self.embed, final_dim) + + if args.quantize_input: + if args.same_quantizer and self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = ( + args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim + ) + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=args.latent_vars, + temp=eval(args.latent_temp), + groups=args.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + ) + self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim) + + self.mask_emb = nn.Parameter( + torch.FloatTensor(args.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(args) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if args.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + if not self.transpose: + self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim) + self.q2h_proj = nn.Linear(final_dim, args.encoder_embed_dim) + else: + self.final_proj = nn.Linear(final_dim, args.encoder_embed_dim) + self.q2h_proj = None + + if getattr(args, 'pretrained_path', None) is not None: + state = checkpoint_utils.load_checkpoint_to_cpu( + args.pretrained_path)['model'] + from collections import OrderedDict + new_state = OrderedDict() + model_state = self.state_dict() + for k, v in state.items(): + if 'w2v_encoder.w2v_model.' in k: + k = k.replace('w2v_encoder.w2v_model.', '') + if k in model_state and model_state[k].size() == v.size(): + new_state[k] = v + logging.info('load parameter {} in Wav2vec2 from path'.format(k, args.pretrained_path)) + self.load_state_dict(new_state, strict=False) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + @classmethod + def build_model(cls, args, task=None): + """Build a new model instance.""" + + # make sure all arguments are present + base_architecture(args) + + return cls(args) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def sample_negatives(self, y, num): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + cross_high = tsz * bsz + high = tsz + with torch.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + neg_idxs = torch.randint( + low=0, high=high - 1, size=(bsz, self.n_negatives * num) + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.cross_sample_negatives) + .flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view( + bsz, num, self.n_negatives + self.cross_sample_negatives, fsz + ).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def compute_preds(self, x, y, negatives): + + neg_is_pos = (y == negatives).all(-1) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + + logits /= self.logit_temp + + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + + return logits + + def forward(self, source, padding_mask=None, mask=True, features_only=False): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + num_vars = None + code_ppl = None + prob_ppl = None + curr_temp = None + + if self.input_quantizer: + q = self.input_quantizer(features, produce_targets=False) + features = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + features = self.project_inp(features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask) + if mask_indices is not None: + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + y = unmasked_features + else: + x = features + y = unmasked_features + mask_indices = None + + x = self.encoder(x, padding_mask=padding_mask) + + + if features_only: + return {"x": x, "padding_mask": padding_mask} + + results = {"features": x, "feautre_padding_mask": padding_mask} + + if self.quantizer: + q = self.quantizer(y, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + + y = self.project_q(y) + + q = self.quantizer(unmasked_features, produce_targets=False) + results["q"] = q["x"] + if self.negatives_from_everywhere: + #neg_cands = self.quantizer(unmasked_features, produce_targets=False) + negs = self.project_q(q['x']) + negs, _ = self.sample_negatives(negs, y.size(1)) + + else: + negs, _ = self.sample_negatives(y, y.size(1)) + + if self.codebook_negatives > 0: + cb_negs = self.quantizer.sample_from_codebook( + y.size(0) * y.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, y.size(0), y.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + negs = torch.cat([negs, cb_negs], dim=0) + else: + y = self.project_q(y) + + if self.negatives_from_everywhere: + negs, _ = self.sample_negatives(unmasked_features, y.size(1)) + negs = self.project_q(negs) + else: + negs, _ = self.sample_negatives(y, y.size(1)) + + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + + if not self.transpose: + x = self.final_proj(x) + results["q"] = self.q2h_proj(results["q"]) + else: + y = self.final_proj(y) + negs = self.final_proj(negs) + results["q"] = self.final_proj(self.project_q(results["q"])) + x = self.compute_preds(x, y, negs) + + + results['x'] = x + results['padding_mask'] = padding_mask + results['features_pen'] = features_pen + if prob_ppl is not None: + results["prob_perplexity"] = prob_ppl + results["code_perplexity"] = code_ppl + results["num_vars"] = num_vars + results["temp"] = curr_temp + + return results + + def quantize(self, x): + assert self.quantizer is not None + x = self.feature_extractor(x) + x = x.transpose(1, 2) + x = self.layer_norm(x) + return self.quantizer.forward_idx(x) + + def extract_features(self, source, padding_mask, mask=False): + res = self.forward(source, padding_mask, mask=mask, features_only=True) + return res["x"], res["padding_mask"] + + def get_logits(self, net_output): + logits = net_output["x"] + logits = logits.transpose(0, 2) + logits = logits.reshape(-1, logits.size(-1)) + return logits + + def get_targets(self, sample, net_output, expand_steps=True): + x = net_output["x"] + return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long) + + def get_extra_losses(self, net_output): + pen = [] + + if "prob_perplexity" in net_output: + pen.append( + (net_output["num_vars"] - net_output["prob_perplexity"]) + / net_output["num_vars"] + ) + + if "features_pen" in net_output: + pen.append(net_output["features_pen"]) + + return pen + + def remove_pretraining_modules(self): + self.quantizer = None + self.project_q = None + self.target_glu = None + self.final_proj = None + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + for _ in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None): + x = self.extract_features(x, padding_mask) + + if self.layer_norm_first: + x = self.layer_norm(x) + + return x + + def extract_features(self, x, padding_mask=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) + layer_results.append(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = utils.get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn + + +@register_model_architecture("wav2vec2_s1", "wav2vec2_s1") +def base_architecture(args): + args.extractor_mode = getattr(args, "extractor_mode", "default") + + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + + args.final_dim = getattr(args, "final_dim", 0) + + args.layer_norm_first = getattr(args, "layer_norm_first", False) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) + + conv_feature_layers = "[(512, 10, 5)]" + conv_feature_layers += " + [(512, 8, 4)]" + conv_feature_layers += " + [(512, 4, 2)] * 3" + conv_feature_layers += " + [(512, 1, 1)]" + args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers) + + args.logit_temp = getattr(args, "logit_temp", 0.1) + + args.quantize_targets = getattr(args, "quantize_targets", False) + args.quantize_input = getattr(args, "quantize_input", False) + args.same_quantizer = getattr(args, "same_quantizer", False) + + args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) + + args.latent_vars = getattr(args, "latent_vars", 320) + args.latent_groups = getattr(args, "latent_groups", 2) + args.latent_dim = getattr(args, "latent_dim", 0) + + args.mask_length = getattr(args, "mask_length", 10) + args.mask_prob = getattr(args, "mask_prob", 0.65) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_min_space = getattr(args, "mask_min_space", 1) + + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) + args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1) + + args.dropout_input = getattr(args, "dropout_input", 0) + args.dropout_features = getattr(args, "dropout_features", 0) + + args.num_negatives = getattr(args, "num_negatives", 100) + args.negatives_from_everywhere = getattr(args, "negatives_from_everywhere", False) + args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0) + args.codebook_negatives = getattr(args, "codebook_negatives", 0) + + args.conv_pos = getattr(args, "conv_pos", 128) + args.conv_pos_groups = getattr(args, "conv_pos_groups", 16) + + args.latent_temp = getattr(args, "latent_temp", "(2,0.5,0.999995)") + + args.target_glu = getattr(args, "target_glu", False) + + args.conv_bias = getattr(args, "conv_bias", False) + args.transpose = getattr(args, "transpose", False) diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py new file mode 100644 index 0000000..e2326ac --- /dev/null +++ b/fairseq/modules/__init__.py @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .adaptive_input import AdaptiveInput +from .adaptive_softmax import AdaptiveSoftmax +from .beamable_mm import BeamableMM +from .character_token_embedder import CharacterTokenEmbedder +from .conv_tbc import ConvTBC +from .cross_entropy import cross_entropy +from .downsampled_multihead_attention import DownsampledMultiHeadAttention +from .dynamic_convolution import DynamicConv, DynamicConv1dTBC +from .dynamic_crf_layer import DynamicCRF +from .fairseq_dropout import FairseqDropout +from .fp32_group_norm import Fp32GroupNorm +from .gelu import gelu, gelu_accurate +from .grad_multiply import GradMultiply +from .gumbel_vector_quantizer import GumbelVectorQuantizer +from .kmeans_vector_quantizer import KmeansVectorQuantizer +from .layer_drop import LayerDropModuleList +from .layer_norm import Fp32LayerNorm, LayerNorm +from .learned_positional_embedding import LearnedPositionalEmbedding +from .lightweight_convolution import LightweightConv, LightweightConv1dTBC +from .linearized_convolution import LinearizedConvolution +from .multihead_attention import MultiheadAttention +from .positional_embedding import PositionalEmbedding +from .same_pad import SamePad +from .scalar_bias import ScalarBias +from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding +from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer +from .transformer_sentence_encoder import TransformerSentenceEncoder +from .transpose_last import TransposeLast +from .unfold import unfold1d +from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer +from .vggblock import VGGBlock + +__all__ = [ + "AdaptiveInput", + "AdaptiveSoftmax", + "BeamableMM", + "CharacterTokenEmbedder", + "ConvTBC", + "cross_entropy", + "DownsampledMultiHeadAttention", + "DynamicConv1dTBC", + "DynamicConv", + "DynamicCRF", + "FairseqDropout", + "Fp32GroupNorm", + "Fp32LayerNorm", + "gelu", + "gelu_accurate", + "GradMultiply", + "GumbelVectorQuantizer", + "KmeansVectorQuantizer", + "LayerDropModuleList", + "LayerNorm", + "LearnedPositionalEmbedding", + "LightweightConv1dTBC", + "LightweightConv", + "LinearizedConvolution", + "MultiheadAttention", + "PositionalEmbedding", + "SamePad", + "ScalarBias", + "SinusoidalPositionalEmbedding", + "TransformerSentenceEncoderLayer", + "TransformerSentenceEncoder", + "TransformerDecoderLayer", + "TransformerEncoderLayer", + "TransposeLast", + "VGGBlock", + "unfold1d", +] diff --git a/fairseq/modules/adaptive_input.py b/fairseq/modules/adaptive_input.py new file mode 100644 index 0000000..446534a --- /dev/null +++ b/fairseq/modules/adaptive_input.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +import torch +from fairseq.modules.quant_noise import quant_noise +from torch import nn + + +class AdaptiveInput(nn.Module): + def __init__( + self, + vocab_size: int, + padding_idx: int, + initial_dim: int, + factor: float, + output_dim: int, + cutoff: List[int], + q_noise: float = 0, + qn_block_size: int = 8, + ): + super().__init__() + + if vocab_size > cutoff[-1]: + cutoff = cutoff + [vocab_size] + else: + assert ( + vocab_size == cutoff[-1] + ), "cannot specify cutoff larger than vocab size" + + self.cutoff = cutoff + self.embedding_dim = output_dim + self.padding_idx = padding_idx + + self.embeddings = nn.ModuleList() + for i in range(len(self.cutoff)): + prev = self.cutoff[i - 1] if i > 0 else 0 + size = self.cutoff[i] - prev + dim = int(initial_dim // (factor ** i)) + seq = nn.Sequential( + nn.Embedding(size, dim, self.padding_idx), + quant_noise( + nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size + ), + ) + + self.embeddings.append(seq) + self.padding_idx = None + self.padding_idx = padding_idx + + def init_weights(m): + if isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + elif hasattr(m, "weight"): + nn.init.xavier_uniform_(m.weight) + + self.apply(init_weights) + + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + + def weights_for_band(self, band: int): + return self.embeddings[band][0].weight, self.embeddings[band][1].weight + + def forward(self, input: torch.Tensor): + result = self._float_tensor.new(input.shape + (self.embedding_dim,)) + for i in range(len(self.cutoff)): + mask = input.lt(self.cutoff[i]) + if i > 0: + mask.mul_(input.ge(self.cutoff[i - 1])) + chunk_input = input[mask] - self.cutoff[i - 1] + else: + chunk_input = input[mask] + if mask.any(): + result[mask] = self.embeddings[i](chunk_input) + return result diff --git a/fairseq/modules/adaptive_softmax.py b/fairseq/modules/adaptive_softmax.py new file mode 100644 index 0000000..ae0c77b --- /dev/null +++ b/fairseq/modules/adaptive_softmax.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import operator + +import torch +import torch.nn.functional as F +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import nn + + +class TiedLinear(nn.Module): + def __init__(self, weight, transpose): + super().__init__() + self.weight = weight + self.transpose = transpose + + def forward(self, input): + return F.linear(input, self.weight.t() if self.transpose else self.weight) + + +class TiedHeadModule(nn.Module): + def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size): + super().__init__() + tied_emb, _ = weights + self.num_words, emb_dim = tied_emb.size() + + self.word_proj = quant_noise( + TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size + ) + if input_dim != emb_dim: + self.word_proj = nn.Sequential( + quant_noise( + nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size + ), + self.word_proj, + ) + + self.class_proj = quant_noise( + nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size + ) + self.out_dim = self.num_words + num_classes + + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + + def forward(self, input): + inp_sz = functools.reduce(operator.mul, input.shape[:-1], 1) + out = self._float_tensor.new(inp_sz, self.out_dim) + out[:, : self.num_words] = self.word_proj(input.view(inp_sz, -1)) + out[:, self.num_words :] = self.class_proj(input.view(inp_sz, -1)) + return out + + +class AdaptiveSoftmax(nn.Module): + """ + This is an implementation of the efficient softmax approximation for + graphical processing units (GPU), described in the paper "Efficient softmax + approximation for GPUs" (http://arxiv.org/abs/1609.04309). + """ + + def __init__( + self, + vocab_size, + input_dim, + cutoff, + dropout, + factor=4.0, + adaptive_inputs=None, + tie_proj=False, + q_noise=0, + qn_block_size=8, + ): + super().__init__() + + if vocab_size > cutoff[-1]: + cutoff = cutoff + [vocab_size] + else: + assert ( + vocab_size == cutoff[-1] + ), "cannot specify cutoff larger than vocab size" + + output_dim = cutoff[0] + len(cutoff) - 1 + + self.vocab_size = vocab_size + self.cutoff = cutoff + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.input_dim = input_dim + self.factor = factor + self.q_noise = q_noise + self.qn_block_size = qn_block_size + + self.lsm = nn.LogSoftmax(dim=1) + + if adaptive_inputs is not None: + self.head = TiedHeadModule( + adaptive_inputs.weights_for_band(0), + input_dim, + len(cutoff) - 1, + self.q_noise, + self.qn_block_size, + ) + else: + self.head = quant_noise( + nn.Linear(input_dim, output_dim, bias=False), + self.q_noise, + self.qn_block_size, + ) + + self._make_tail(adaptive_inputs, tie_proj) + + def init_weights(m): + if ( + hasattr(m, "weight") + and not isinstance(m, TiedLinear) + and not isinstance(m, TiedHeadModule) + ): + nn.init.xavier_uniform_(m.weight) + + self.apply(init_weights) + + self.register_buffer("version", torch.LongTensor([1])) + + def _make_tail(self, adaptive_inputs=None, tie_proj=False): + self.tail = nn.ModuleList() + for i in range(len(self.cutoff) - 1): + dim = int(self.input_dim // self.factor ** (i + 1)) + + tied_emb, tied_proj = ( + adaptive_inputs.weights_for_band(i + 1) + if adaptive_inputs is not None + else (None, None) + ) + + if tied_proj is not None: + if tie_proj: + proj = quant_noise( + TiedLinear(tied_proj, transpose=True), + self.q_noise, + self.qn_block_size, + ) + else: + proj = quant_noise( + nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False), + self.q_noise, + self.qn_block_size, + ) + else: + proj = quant_noise( + nn.Linear(self.input_dim, dim, bias=False), + self.q_noise, + self.qn_block_size, + ) + + if tied_emb is None: + out_proj = nn.Linear( + dim, self.cutoff[i + 1] - self.cutoff[i], bias=False + ) + else: + out_proj = TiedLinear(tied_emb, transpose=False) + + m = nn.Sequential( + proj, + nn.Dropout(self.dropout_module.p), + quant_noise(out_proj, self.q_noise, self.qn_block_size), + ) + + self.tail.append(m) + + def upgrade_state_dict_named(self, state_dict, name): + version_name = name + ".version" + if version_name not in state_dict: + raise Exception("This version of the model is no longer supported") + + def adapt_target(self, target): + """ + In order to be efficient, the AdaptiveSoftMax does not compute the + scores for all the word of the vocabulary for all the examples. It is + thus necessary to call the method adapt_target of the AdaptiveSoftMax + layer inside each forward pass. + """ + + target = target.view(-1) + new_target = [target.clone()] + target_idxs = [] + + for i in range(len(self.cutoff) - 1): + mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1])) + new_target[0][mask] = self.cutoff[0] + i + + if mask.any(): + target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1)) + new_target.append(target[mask].add(-self.cutoff[i])) + else: + target_idxs.append(None) + new_target.append(None) + + return new_target, target_idxs + + def forward(self, input, target): + """ + Args: + input: (b x t x d) + target: (b x t) + Returns: + 2 lists: output for each cutoff section and new targets by cut off + """ + + input = input.contiguous().view(-1, input.size(-1)) + input = self.dropout_module(input) + + new_target, target_idxs = self.adapt_target(target) + output = [self.head(input)] + + for i in range(len(target_idxs)): + if target_idxs[i] is not None: + output.append(self.tail[i](input.index_select(0, target_idxs[i]))) + else: + output.append(None) + + return output, new_target + + def get_log_prob(self, input, target): + """ + Computes the log probabilities for all the words of the vocabulary, + given a 2D tensor of hidden vectors. + """ + + bsz, length, dim = input.size() + input = input.contiguous().view(-1, dim) + + if target is not None: + _, target_idxs = self.adapt_target(target) + else: + target_idxs = None + + head_y = self.head(input) + log_probs = head_y.new_zeros(input.size(0), self.vocab_size) + + head_sz = self.cutoff[0] + len(self.tail) + log_probs[:, :head_sz] = self.lsm(head_y) + tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone() + + for i in range(len(self.tail)): + start = self.cutoff[i] + end = self.cutoff[i + 1] + + if target_idxs is None: + tail_out = log_probs[:, start:end] + tail_out.copy_(self.tail[i](input)) + log_probs[:, start:end] = self.lsm(tail_out).add_( + tail_priors[:, i, None] + ) + elif target_idxs[i] is not None: + idxs = target_idxs[i] + tail_out = log_probs[idxs, start:end] + tail_out.copy_(self.tail[i](input[idxs])) + log_probs[idxs, start:end] = self.lsm(tail_out).add_( + tail_priors[idxs, i, None] + ) + + log_probs = log_probs.view(bsz, length, -1) + return log_probs diff --git a/fairseq/modules/beamable_mm.py b/fairseq/modules/beamable_mm.py new file mode 100644 index 0000000..eff1a46 --- /dev/null +++ b/fairseq/modules/beamable_mm.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +class BeamableMM(nn.Module): + """This module provides an optimized MM for beam decoding with attention. + + It leverage the fact that the source-side of the input is replicated beam + times and the target-side of the input is of width one. This layer speeds up + inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} + with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. + """ + + def __init__(self, beam_size=None): + super(BeamableMM, self).__init__() + self.beam_size = beam_size + + def forward(self, input1, input2): + if ( + not self.training + and self.beam_size is not None # test mode + and input1.dim() == 3 # beam size is set + and input1.size(1) # only support batched input + == 1 # single time step update + ): + bsz, beam = input1.size(0), self.beam_size + + # bsz x 1 x nhu --> bsz/beam x beam x nhu + input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) + + # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu + input2 = input2.unfold(0, beam, beam)[:, :, :, 0] + + # use non batched operation if bsz = beam + if input1.size(0) == 1: + output = torch.mm(input1[0, :, :], input2[0, :, :]) + else: + output = input1.bmm(input2) + return output.view(bsz, 1, -1) + else: + return input1.bmm(input2) + + def set_beam_size(self, beam_size): + self.beam_size = beam_size diff --git a/fairseq/modules/character_token_embedder.py b/fairseq/modules/character_token_embedder.py new file mode 100644 index 0000000..181221b --- /dev/null +++ b/fairseq/modules/character_token_embedder.py @@ -0,0 +1,214 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from fairseq.data import Dictionary +from torch import nn + + +CHAR_PAD_IDX = 0 +CHAR_EOS_IDX = 257 + + +logger = logging.getLogger(__name__) + + +class CharacterTokenEmbedder(torch.nn.Module): + def __init__( + self, + vocab: Dictionary, + filters: List[Tuple[int, int]], + char_embed_dim: int, + word_embed_dim: int, + highway_layers: int, + max_char_len: int = 50, + char_inputs: bool = False, + ): + super(CharacterTokenEmbedder, self).__init__() + + self.onnx_trace = False + self.embedding_dim = word_embed_dim + self.max_char_len = max_char_len + self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0) + self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim)) + self.eos_idx, self.unk_idx = 0, 1 + self.char_inputs = char_inputs + + self.convolutions = nn.ModuleList() + for width, out_c in filters: + self.convolutions.append( + nn.Conv1d(char_embed_dim, out_c, kernel_size=width) + ) + + last_dim = sum(f[1] for f in filters) + + self.highway = Highway(last_dim, highway_layers) if highway_layers > 0 else None + + self.projection = nn.Linear(last_dim, word_embed_dim) + + assert ( + vocab is not None or char_inputs + ), "vocab must be set if not using char inputs" + self.vocab = None + if vocab is not None: + self.set_vocab(vocab, max_char_len) + + self.reset_parameters() + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def set_vocab(self, vocab, max_char_len): + word_to_char = torch.LongTensor(len(vocab), max_char_len) + + truncated = 0 + for i in range(len(vocab)): + if i < vocab.nspecial: + char_idxs = [0] * max_char_len + else: + chars = vocab[i].encode() + # +1 for padding + char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars)) + if len(char_idxs) > max_char_len: + truncated += 1 + char_idxs = char_idxs[:max_char_len] + word_to_char[i] = torch.LongTensor(char_idxs) + + if truncated > 0: + logger.info( + "truncated {} words longer than {} characters".format( + truncated, max_char_len + ) + ) + + self.vocab = vocab + self.word_to_char = word_to_char + + @property + def padding_idx(self): + return Dictionary().pad() if self.vocab is None else self.vocab.pad() + + def reset_parameters(self): + nn.init.xavier_normal_(self.char_embeddings.weight) + nn.init.xavier_normal_(self.symbol_embeddings) + nn.init.xavier_uniform_(self.projection.weight) + + nn.init.constant_( + self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.0 + ) + nn.init.constant_(self.projection.bias, 0.0) + + def forward( + self, + input: torch.Tensor, + ): + if self.char_inputs: + chars = input.view(-1, self.max_char_len) + pads = chars[:, 0].eq(CHAR_PAD_IDX) + eos = chars[:, 0].eq(CHAR_EOS_IDX) + if eos.any(): + if self.onnx_trace: + chars = torch.where(eos.unsqueeze(1), chars.new_zeros(1), chars) + else: + chars[eos] = 0 + + unk = None + else: + flat_words = input.view(-1) + chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as( + input + ) + pads = flat_words.eq(self.vocab.pad()) + eos = flat_words.eq(self.vocab.eos()) + unk = flat_words.eq(self.vocab.unk()) + + word_embs = self._convolve(chars) + if self.onnx_trace: + if pads.any(): + word_embs = torch.where( + pads.unsqueeze(1), word_embs.new_zeros(1), word_embs + ) + if eos.any(): + word_embs = torch.where( + eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs + ) + if unk is not None and unk.any(): + word_embs = torch.where( + unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs + ) + else: + if pads.any(): + word_embs[pads] = 0 + if eos.any(): + word_embs[eos] = self.symbol_embeddings[self.eos_idx] + if unk is not None and unk.any(): + word_embs[unk] = self.symbol_embeddings[self.unk_idx] + + return word_embs.view(input.size()[:2] + (-1,)) + + def _convolve( + self, + char_idxs: torch.Tensor, + ): + char_embs = self.char_embeddings(char_idxs) + char_embs = char_embs.transpose(1, 2) # BTC -> BCT + + conv_result = [] + + for conv in self.convolutions: + x = conv(char_embs) + x, _ = torch.max(x, -1) + x = F.relu(x) + conv_result.append(x) + + x = torch.cat(conv_result, dim=-1) + + if self.highway is not None: + x = self.highway(x) + x = self.projection(x) + + return x + + +class Highway(torch.nn.Module): + """ + A `Highway layer `_. + Adopted from the AllenNLP implementation. + """ + + def __init__(self, input_dim: int, num_layers: int = 1): + super(Highway, self).__init__() + self.input_dim = input_dim + self.layers = nn.ModuleList( + [nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)] + ) + self.activation = nn.ReLU() + + self.reset_parameters() + + def reset_parameters(self): + for layer in self.layers: + # As per comment in AllenNLP: + # We should bias the highway layer to just carry its input forward. We do that by + # setting the bias on `B(x)` to be positive, because that means `g` will be biased to + # be high, so we will carry the input forward. The bias on `B(x)` is the second half + # of the bias vector in each Linear layer. + nn.init.constant_(layer.bias[self.input_dim :], 1) + + nn.init.constant_(layer.bias[: self.input_dim], 0) + nn.init.xavier_normal_(layer.weight) + + def forward(self, x: torch.Tensor): + for layer in self.layers: + projection = layer(x) + proj_x, gate = projection.chunk(2, dim=-1) + proj_x = self.activation(proj_x) + gate = torch.sigmoid(gate) + x = gate * x + (gate.new_tensor([1]) - gate) * proj_x + return x diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py new file mode 100644 index 0000000..a4341fe --- /dev/null +++ b/fairseq/modules/checkpoint_activations.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Tuple, Union + +import torch + +from fairseq import utils + + +def checkpoint_wrapper(m): + """ + A friendlier wrapper for performing activation checkpointing. + + Compared to the PyTorch version, this version: + - wraps an nn.Module, so that all subsequent calls will use checkpointing + - handles keyword arguments in the forward + - handles non-Tensor outputs from the forward + + Usage:: + + checkpointed_module = checkpoint_wrapper(my_module) + a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) + """ + original_forward = m.forward + + def _checkpointed_forward(*args, **kwargs): + # Autograd Functions in PyTorch work best with positional args, since + # the backward must return gradients (or None) for every input argument. + # We can flatten keyword arguments to make this easier. + kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) + parent_ctx_dict = {} + output = CheckpointFunction.apply( + original_forward, parent_ctx_dict, kwarg_keys, *flat_args + ) + if isinstance(output, torch.Tensor): + return output + else: + packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"] + if packed_non_tensor_outputs: + output = unpack_non_tensors(output, packed_non_tensor_outputs) + return output + + m.forward = _checkpointed_forward + return m + + +def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: + """ + Usage:: + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == [1, 2] + assert kwargs == {"a": 3, "b": 4} + """ + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + return kwarg_keys, flat_args + + +def unpack_kwargs( + kwarg_keys: List[str], flat_args: List[Any] +) -> Tuple[List[Any], Dict[str, Any]]: + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])} + return args, kwargs + + +def split_non_tensors( + mixed: Union[torch.Tensor, Tuple[Any]] +) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]: + """ + Usage:: + + x = torch.Tensor([1]) + y = torch.Tensor([2]) + tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y, None, 3) + """ + if isinstance(mixed, torch.Tensor): + return (mixed,), None + assert isinstance(mixed, tuple) + tensors = [] + packed_non_tensors = {"is_tensor": [], "objects": []} + for o in mixed: + if isinstance(o, torch.Tensor): + packed_non_tensors["is_tensor"].append(True) + tensors.append(o) + else: + packed_non_tensors["is_tensor"].append(False) + packed_non_tensors["objects"].append(o) + return tuple(tensors), packed_non_tensors + + +def unpack_non_tensors( + tensors: Tuple[torch.Tensor], + packed_non_tensors: Dict[str, List[Any]], +) -> Tuple[Any]: + if packed_non_tensors is None: + return tensors + assert isinstance(packed_non_tensors, dict) + mixed = [] + is_tensor_list = packed_non_tensors["is_tensor"] + objects = packed_non_tensors["objects"] + assert len(tensors) + len(objects) == len(is_tensor_list) + obj_i = tnsr_i = 0 + for is_tensor in is_tensor_list: + if is_tensor: + mixed.append(tensors[tnsr_i]) + tnsr_i += 1 + else: + mixed.append(objects[obj_i]) + obj_i += 1 + return tuple(mixed) + + +class CheckpointFunction(torch.autograd.Function): + """Similar to the torch version, but support non-Tensor outputs. + + The caller is expected to provide a dict (*parent_ctx_dict*) that will hold + the non-Tensor outputs. These should be combined with the Tensor *outputs* + by calling ``unpack_non_tensors``. + """ + + @staticmethod + def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): + if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation + torch.utils.checkpoint.check_backward_validity(args) + + ctx.run_function = run_function + ctx.kwarg_keys = kwarg_keys + ctx.fwd_rng_state = utils.get_rng_state() + + tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) + ctx.save_for_backward(*tensor_inputs) + ctx.packed_non_tensor_inputs = packed_non_tensor_inputs + + with torch.no_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) + outputs = run_function(*unpacked_args, **unpacked_kwargs) + + if isinstance(outputs, torch.Tensor): + return outputs + else: + assert isinstance(outputs, tuple) + # Autograd Functions don't like non-Tensor outputs. We can split the + # non-Tensor and Tensor outputs, returning the former by reference + # through *parent_ctx_dict* and returning the latter directly. + outputs, packed_non_tensor_outputs = split_non_tensors(outputs) + parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad(), please use .backward() if possible" + ) + + tensor_inputs = ctx.saved_tensors + tensor_inputs = torch.utils.checkpoint.detach_variable(tensor_inputs) + inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) + + # Store the current states. + bwd_rng_state = utils.get_rng_state() + + # Set the states to what it used to be before the forward pass. + utils.set_rng_state(ctx.fwd_rng_state) + + with torch.enable_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) + outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) + tensor_outputs, _ = split_non_tensors(outputs) + + # Set the states back to what it was at the start of this function. + utils.set_rng_state(bwd_rng_state) + + # Run backward() with only Tensors that require grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(tensor_outputs)): + if tensor_outputs[i].requires_grad: + outputs_with_grad.append(tensor_outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "None of the outputs have requires_grad=True, " + "this checkpoint() is not necessary" + ) + + torch.autograd.backward(outputs_with_grad, args_with_grad) + + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs + ) + return (None, None, None) + grads diff --git a/fairseq/modules/conv_tbc.py b/fairseq/modules/conv_tbc.py new file mode 100644 index 0000000..2dc46c4 --- /dev/null +++ b/fairseq/modules/conv_tbc.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.nn.modules.utils import _single + + +class ConvTBC(torch.nn.Module): + """1D convolution over an input of shape (time x batch x channel) + + The implementation uses gemm to perform the convolution. This implementation + is faster than cuDNN for small kernel sizes. + """ + + def __init__(self, in_channels, out_channels, kernel_size, padding=0): + super(ConvTBC, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _single(kernel_size) + self.padding = _single(padding) + + self.weight = torch.nn.Parameter( + torch.Tensor(self.kernel_size[0], in_channels, out_channels) + ) + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + + def forward(self, input): + return torch.conv_tbc( + input.contiguous(), self.weight, self.bias, self.padding[0] + ) + + def __repr__(self): + s = ( + "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}" + ", padding={padding}" + ) + if self.bias is None: + s += ", bias=False" + s += ")" + return s.format(name=self.__class__.__name__, **self.__dict__) diff --git a/fairseq/modules/cross_entropy.py b/fairseq/modules/cross_entropy.py new file mode 100644 index 0000000..0d2beb4 --- /dev/null +++ b/fairseq/modules/cross_entropy.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn.functional as F + + +logger = logging.getLogger(__name__) + + +def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"): + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + return F.nll_loss( + lprobs, + target, + ignore_index=ignore_index, + reduction=reduction, + ) + + +try: + import xentropy_cuda + from apex.contrib import xentropy + + logger.info("using fused cross entropy") + + def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): + if logits.device == torch.device("cpu"): + return _cross_entropy_pytorch(logits, target, ignore_index, reduction) + else: + half_to_float = logits.dtype == torch.half + losses = xentropy.SoftmaxCrossEntropyLoss.apply( + logits, + target, + 0.0, + ignore_index, + half_to_float, + ) + if reduction == "sum": + return losses.sum() + elif reduction == "mean": + if ignore_index >= 0: + return losses.sum() / target.ne(ignore_index).sum() + else: + return losses.mean() + elif reduction == "none": + return losses + else: + raise NotImplementedError + + +except ImportError: + + def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): + return _cross_entropy_pytorch(logits, target, ignore_index, reduction) diff --git a/fairseq/modules/cuda_utils.cu b/fairseq/modules/cuda_utils.cu new file mode 100644 index 0000000..516f1d9 --- /dev/null +++ b/fairseq/modules/cuda_utils.cu @@ -0,0 +1,203 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + return (a + b - 1) / b; +} + + +template +__inline__ __device__ +void zeroSharedMem(scalar_t* data) { + /* + Given an array of length FS + SB, zero out the first padding_l and last + (FS - padding_l) values in the array + */ + + int tid = threadIdx.x; + + if (FS < SB) { + + // zero all if we have enough threads in a block to do all of them + if (tid < padding_l || tid > SB - FS + padding_l - 1) { + data[tid] = scalar_t(0.0); + } + } else { + + // otherwise zero out one block at a time + const int numIterations = divUp(FS, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if (tid + offset < padding_l) { + data[tid + offset] = scalar_t(0.0); + } else if (tid + offset < FS) { + data[SB + tid + offset] = scalar_t(0.0); + } + } + } +} + +template +__inline__ __device__ +scalar_t warpReduce(scalar_t data) { + /* + Reduce an array within each warp. After processing all values in warp will + caontain the sum of all original values in that warp. + + data - pointer to data to reduce + */ + data += __shfl_xor_sync(SHFL_MASK, data, 16); + data += __shfl_xor_sync(SHFL_MASK, data, 8); + data += __shfl_xor_sync(SHFL_MASK, data, 4); + data += __shfl_xor_sync(SHFL_MASK, data, 2); + data += __shfl_xor_sync(SHFL_MASK, data, 1); + return data; +} + +template +__inline__ __device__ +scalar_t blockReduce(scalar_t data) { + /* + Reduce an entire array on the block level. After processing, the + first value in the array will contain the reduced sum. + + data - pointer to data to reduce + */ + + static __shared__ scalar_t warpSum[32]; + const int tid = threadIdx.x; + int wid = tid / 32; + int lane = tid % 32; + + __syncthreads(); + + // reduce each warp then write to shared memory + scalar_t sum = warpReduce(data); + if (lane == 0) { + warpSum[wid] = sum; + } + + __syncthreads(); + + scalar_t v; + // perform final sum of partial warp sums + if (tid < blockDim.x / 32) { + v = warpSum[lane]; + } else { + v = scalar_t(0.0); + } + + if (wid == 0) { + v = warpReduce(v); + } + __syncthreads(); + + return v; +} + +void checkCudaStatus(cudaError_t status, int lineNumber = -1) { + + if (status != cudaSuccess) { + std::cout << cudaGetErrorString(status) + << " at line " << lineNumber << std::endl; + std::cout << "Exiting" << std::endl; + exit(1); + } +} + +template +__device__ +void load_input_to_shared(const scalar_t* input, // global memory + int inputOffset, int sequenceLength, + int iteration, int numIterations, + bool no_prev, scalar_t* output /* shared memory */) { + /* + Load a block size of input into shared memory with + right and left overhang of total size FS. If previously + loaded memory, overlap will be shifted over to reduce + global memory access + + input - pointer to start of channel sequence + inputOffset - how far in the sequence to start loading + sequenceLength - total length of sequence + iteration - which block of sequence we are loading + numIterations - total number of blocks to load + no_prev - whether to load the whole block if the previous block + wasn't loaded + output - shared memory to write input to + */ + + const int tid = threadIdx.x; + + // Load the left "overhang" of input + if (iteration > 0) { + if (padding_l < SB) { + + // load all at once + if (tid < padding_l) { + output[tid] = (no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB]; + } + } else { + + // load in chunks of size SB + int numIterations = divUp(padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < padding_l) { + output[tid + offset] = (no_prev) ? input[inputOffset - padding_l + tid + offset] : output[tid + offset + SB]; + } + } + } + } + + // Load the right "overhang" of input + if (iteration < (numIterations - 1)) { + const int elementsLeft = sequenceLength - (iteration+1) * SB; + + if ((FS - padding_l) < SB) { + + // load all at once + if (tid < (FS - padding_l)) { + output[padding_l + SB + tid] = (tid < elementsLeft) ? input[inputOffset + SB + tid] : scalar_t(0.0); + } + } else { + + // load in chunks of size SB + int numIterations = divUp(FS - padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < (FS - padding_l)) { + output[padding_l + SB + tid + offset] = ((tid + offset) < elementsLeft) ? input[inputOffset + SB + tid + offset] : scalar_t(0.0); + } + } + } + } + + // We should also clear out the right "overhang" + if (iteration == (numIterations - 1)) { + if ((FS - padding_l) < SB) { + + // clear out all at once + if (tid < (FS - padding_l)) { + output[padding_l + SB + tid] = scalar_t(0.0); + } + } else { + + // clear in chunks of size SB + int numIterations = divUp(FS - padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < (FS - padding_l)) { + output[padding_l + SB + tid + offset] = scalar_t(0.0); + } + } + } + } + output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) ? input[inputOffset + tid] : scalar_t(0.0); +} diff --git a/fairseq/modules/downsampled_multihead_attention.py b/fairseq/modules/downsampled_multihead_attention.py new file mode 100644 index 0000000..2cdece3 --- /dev/null +++ b/fairseq/modules/downsampled_multihead_attention.py @@ -0,0 +1,316 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.scalar_bias import scalar_bias + + +class SingleHeadAttention(nn.Module): + """ + Single-head attention that supports Gating and Downsampling + """ + + def __init__( + self, + out_channels, + embed_dim, + head_dim, + head_index, + dropout=0.0, + bias=True, + project_input=True, + gated=False, + downsample=False, + num_heads=1, + ): + super().__init__() + self.embed_dim = embed_dim + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.head_index = head_index + self.head_dim = head_dim + self.project_input = project_input + self.gated = gated + self.downsample = downsample + self.num_heads = num_heads + self.projection = None + + k_layers = [] + v_layers = [] + if self.downsample: + k_layers.append(Downsample(self.head_index)) + v_layers.append(Downsample(self.head_index)) + out_proj_size = self.head_dim + else: + out_proj_size = self.head_dim * self.num_heads + if self.gated: + k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias)) + self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias) + v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias)) + else: + k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias)) + self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias) + v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias)) + + self.in_proj_k = nn.Sequential(*k_layers) + self.in_proj_v = nn.Sequential(*v_layers) + + if self.downsample: + self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias) + else: + self.out_proj = Linear(out_proj_size, out_channels, bias=bias) + + self.scaling = self.head_dim ** -0.5 + + def forward( + self, + query, + key, + value, + mask_future_timesteps=False, + key_padding_mask=None, + use_scalar_bias=False, + ): + """Input shape: Time x Batch x Channel + Self-attention can be implemented by passing in the same arguments for + query, key and value. Future timesteps can be masked with the + `mask_future_timesteps` argument. Padding elements can be excluded from + the key by passing a binary ByteTensor (`key_padding_mask`) with shape: + batch x src_len, where padding elements are indicated by 1s. + """ + src_len, bsz, out_channels = key.size() + tgt_len = query.size(0) + assert list(query.size()) == [tgt_len, bsz, out_channels] + assert key.size() == value.size() + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.downsample: + size = bsz + else: + size = bsz * self.num_heads + + k = key + v = value + q = query + if self.project_input: + q = self.in_proj_q(q) + k = self.in_proj_k(k) + v = self.in_proj_v(v) + src_len = k.size()[0] + q *= self.scaling + + if not self.downsample: + q = q.view(tgt_len, size, self.head_dim) + k = k.view(src_len, size, self.head_dim) + v = v.view(src_len, size, self.head_dim) + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + if mask_future_timesteps: + assert ( + query.size() == key.size() + ), "mask_future_timesteps only applies to self-attention" + attn_weights *= torch.tril( + attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(), + diagonal=-1, + )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0) + attn_weights += torch.triu( + attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(), + diagonal=0, + )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0) + tgt_size = tgt_len + if use_scalar_bias: + attn_weights = scalar_bias(attn_weights, 2) + v = scalar_bias(v, 1) + tgt_size += 1 + + if key_padding_mask is not None: + # don't attend to padding symbols + if key_padding_mask.max() > 0: + if self.downsample: + attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len) + else: + attn_weights = attn_weights.view( + size, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -math.inf, + ) + attn_weights = attn_weights.view(size, tgt_len, src_len) + attn_weights = F.softmax(attn_weights, dim=-1) + attn_weights = self.dropout_module(attn_weights) + + attn = torch.bmm(attn_weights, v) + if self.downsample: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + + attn = self.out_proj(attn) + + return attn, attn_weights + + +class DownsampledMultiHeadAttention(nn.ModuleList): + """ + Multi-headed attention with Gating and Downsampling + """ + + def __init__( + self, + out_channels, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + project_input=True, + gated=False, + downsample=False, + ): + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.downsample = downsample + self.gated = gated + self.project_input = project_input + assert self.head_dim * num_heads == embed_dim + + if self.downsample: + attention_heads = [] + for index in range(self.num_heads): + attention_heads.append( + SingleHeadAttention( + out_channels, + self.embed_dim, + self.head_dim, + index, + dropout, + bias, + self.project_input, + self.gated, + self.downsample, + self.num_heads, + ) + ) + super().__init__(modules=attention_heads) + self.out_proj = Linear(embed_dim, out_channels, bias=bias) + else: + # either we have a list of attention heads, or just one attention head + # if not being downsampled, we can do the heads with one linear layer instead of separate ones + super().__init__() + self.attention_module = SingleHeadAttention( + out_channels, + self.embed_dim, + self.head_dim, + 1, + dropout, + bias, + self.project_input, + self.gated, + self.downsample, + self.num_heads, + ) + + def forward( + self, + query, + key, + value, + mask_future_timesteps=False, + key_padding_mask=None, + use_scalar_bias=False, + ): + src_len, bsz, embed_dim = key.size() + tgt_len = query.size(0) + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + assert key.size() == value.size() + + tgt_size = tgt_len + if use_scalar_bias: + tgt_size += 1 + + attn = [] + attn_weights = [] + if self.downsample: + for attention_head_number in range(self.num_heads): + # call the forward of each attention head + _attn, _attn_weight = self[attention_head_number]( + query, + key, + value, + mask_future_timesteps, + key_padding_mask, + use_scalar_bias, + ) + attn.append(_attn) + attn_weights.append(_attn_weight) + full_attn = torch.cat(attn, dim=2) + full_attn = self.out_proj(full_attn) + return full_attn, attn_weights[0].clone() + else: + _attn, _attn_weight = self.attention_module( + query, + key, + value, + mask_future_timesteps, + key_padding_mask, + use_scalar_bias, + ) + attn.append(_attn) + attn_weights.append(_attn_weight) + full_attn = torch.cat(attn, dim=2) + full_attn_weights = torch.cat(attn_weights) + full_attn_weights = full_attn_weights.view( + bsz, self.num_heads, tgt_size, src_len + ) + full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads + return full_attn, full_attn_weights + + +class Downsample(nn.Module): + """ + Selects every nth element, where n is the index + """ + + def __init__(self, index): + super().__init__() + self.index = index + + def forward(self, x): + return x[:: self.index + 1] + + +def Linear(in_features, out_features, dropout=0.0, bias=True): + """Weight-normalized Linear layer (input: B x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features)) + m.bias.data.zero_() + return nn.utils.weight_norm(m) + + +def GatedLinear(in_features, out_features, dropout=0.0, bias=True): + """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units""" + return nn.Sequential( + Linear(in_features, out_features * 4, dropout, bias), + nn.GLU(), + Linear(out_features * 2, out_features * 2, dropout, bias), + nn.GLU(), + Linear(out_features, out_features, dropout, bias), + ) diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py new file mode 100644 index 0000000..5999a04 --- /dev/null +++ b/fairseq/modules/dynamic_convolution.py @@ -0,0 +1,304 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout + +from .unfold import unfold1d + + +def DynamicConv( + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + renorm_padding=False, + bias=False, + conv_bias=False, + query_size=None, + in_proj=False, +): + if torch.cuda.is_available(): + try: + from fairseq.modules.dynamicconv_layer import DynamicconvLayer + + return DynamicconvLayer( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) + except ImportError as e: + print(e) + return DynamicConv1dTBC( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +@with_incremental_state +class DynamicConv1dTBC(nn.Module): + """Dynamic lightweight convolution taking T x B x C inputs + Args: + input_size: # of channels of the input + kernel_size: convolution channels + padding_l: padding to the left when using "same" padding + num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) + weight_dropout: the drop rate of the DropConnect to drop the weight + weight_softmax: normalize the weight with softmax before the convolution + renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1) + bias: use bias + conv_bias: bias of the convolution + query_size: specified when feeding a different input as the query + in_proj: project the input and generate the filter together + + Shape: + Input: TxBxC, i.e. (timesteps, batch_size, input_size) + Output: TxBxC, i.e. (timesteps, batch_size, input_size) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + renorm_padding=False, + bias=False, + conv_bias=False, + query_size=None, + in_proj=False, + ): + super().__init__() + self.input_size = input_size + self.query_size = input_size if query_size is None else query_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.weight_softmax = weight_softmax + self.renorm_padding = renorm_padding + + if in_proj: + self.weight_linear = Linear( + self.input_size, self.input_size + num_heads * kernel_size * 1 + ) + else: + self.weight_linear = Linear( + self.query_size, num_heads * kernel_size * 1, bias=bias + ) + if conv_bias: + self.conv_bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.conv_bias = None + self.reset_parameters() + + @property + def in_proj(self): + return ( + self.weight_linear.out_features + == self.input_size + self.num_heads * self.kernel_size + ) + + def reset_parameters(self): + self.weight_linear.reset_parameters() + if self.conv_bias is not None: + nn.init.constant_(self.conv_bias, 0.0) + + def forward(self, x, incremental_state=None, query=None, unfold=None): + """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + args: + x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) + incremental_state: A dict to keep the state + unfold: unfold the input or not. If not, we use the matrix trick instead + query: use the specified query to predict the conv filters + """ + unfold = ( + x.size(0) > 512 if unfold is None else unfold + ) # use unfold mode as default for long sequence to save memory + unfold = unfold or (incremental_state is not None) + assert query is None or not self.in_proj + + if query is None: + query = x + if unfold: + output = self._forward_unfolded(x, incremental_state, query) + else: + output = self._forward_expanded(x, incremental_state, query) + + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + return output + + def _forward_unfolded(self, x, incremental_state, query): + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + if self.in_proj: + proj = self.weight_linear(x) + x = proj.narrow(2, 0, self.input_size).contiguous() + weight = ( + proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1) + ) + else: + weight = self.weight_linear(query).view(T * B * H, -1) + + # renorm_padding is only implemented in _forward_expanded + assert not self.renorm_padding or incremental_state is not None + + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) + else: + padding_l = self.padding_l + if K > T and padding_l == K - 1: + weight = weight.narrow(1, K - T, T) + K, padding_l = T, T - 1 + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, K, padding_l, 0) + x_unfold = x_unfold.view(T * B * H, R, K) + + if self.weight_softmax and not self.renorm_padding: + weight = F.softmax(weight, dim=1) + weight = weight.narrow(1, 0, K) + + if incremental_state is not None: + weight = weight[:, -x_unfold.size(2) :] + K = weight.size(1) + + if self.weight_softmax and self.renorm_padding: + weight = F.softmax(weight, dim=1) + + weight = self.weight_dropout_module(weight, inplace=False) + + output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + def _forward_expanded(self, x, incremental_stat, query): + """Turn the convolution filters into band matrices and do matrix multiplication. + This is faster when the sequence is short, but less memory efficient. + This is not used in the decoder during inference. + """ + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + if self.in_proj: + proj = self.weight_linear(x) + x = proj.narrow(2, 0, self.input_size).contiguous() + weight = ( + proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1) + ) + else: + weight = self.weight_linear(query).view(T * B * H, -1) + + if not self.renorm_padding: + if self.weight_softmax: + weight = F.softmax(weight, dim=1) + weight = self.weight_dropout_module(weight, inplace=False) + weight = weight.narrow(1, 0, K).contiguous() + weight = weight.view(T, B * H, K).transpose(0, 1) + + x = x.view(T, B * H, R).transpose(0, 1) + if self.weight_softmax and self.renorm_padding: + # turn the convolution filters into band matrices + weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf")) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) + weight_expanded = weight_expanded.narrow(2, self.padding_l, T) + # normalize the weight over valid positions like self-attention + weight_expanded = F.softmax(weight_expanded, dim=2) + weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False) + else: + P = self.padding_l + # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) + weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def extra_repr(self): + s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format( + self.input_size, + self.kernel_size, + self.padding_l, + self.num_heads, + self.weight_softmax, + self.conv_bias is not None, + self.renorm_padding, + self.in_proj, + ) + + if self.query_size != self.input_size: + s += ", query_size={}".format(self.query_size) + if self.weight_dropout_module.p > 0.0: + s += ", weight_dropout={}".format(self.weight_dropout_module.p) + return s diff --git a/fairseq/modules/dynamic_crf_layer.py b/fairseq/modules/dynamic_crf_layer.py new file mode 100644 index 0000000..8fcc6b8 --- /dev/null +++ b/fairseq/modules/dynamic_crf_layer.py @@ -0,0 +1,189 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This file is to re-implemented the low-rank and beam approximation of CRF layer +Proposed by: + +Sun, Zhiqing, et al. +Fast Structured Decoding for Sequence Models +https://arxiv.org/abs/1910.11555 + +The CRF implementation is mainly borrowed from +https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py + +""" + +import numpy as np +import torch +import torch.nn as nn + + +def logsumexp(x, dim=1): + return torch.logsumexp(x.float(), dim=dim).type_as(x) + + +class DynamicCRF(nn.Module): + """Dynamic CRF layer is used to approximate the traditional + Conditional Random Fields (CRF) + $P(y | x) = 1/Z(x) exp(sum_i s(y_i, x) + sum_i t(y_{i-1}, y_i, x))$ + + where in this function, we assume the emition scores (s) are given, + and the transition score is a |V| x |V| matrix $M$ + + in the following two aspects: + (1) it used a low-rank approximation for the transition matrix: + $M = E_1 E_2^T$ + (2) it used a beam to estimate the normalizing factor Z(x) + """ + + def __init__(self, num_embedding, low_rank=32, beam_size=64): + super().__init__() + + self.E1 = nn.Embedding(num_embedding, low_rank) + self.E2 = nn.Embedding(num_embedding, low_rank) + + self.vocb = num_embedding + self.rank = low_rank + self.beam = beam_size + + def extra_repr(self): + return "vocab_size={}, low_rank={}, beam_size={}".format( + self.vocb, self.rank, self.beam + ) + + def forward(self, emissions, targets, masks, beam=None): + """ + Compute the conditional log-likelihood of a sequence of target tokens given emission scores + + Args: + emissions (`~torch.Tensor`): Emission score are usually the unnormalized decoder output + ``(batch_size, seq_len, vocab_size)``. We assume batch-first + targets (`~torch.LongTensor`): Sequence of target token indices + ``(batch_size, seq_len) + masks (`~torch.ByteTensor`): Mask tensor with the same size as targets + + Returns: + `~torch.Tensor`: approximated log-likelihood + """ + numerator = self._compute_score(emissions, targets, masks) + denominator = self._compute_normalizer(emissions, targets, masks, beam) + return numerator - denominator + + def forward_decoder(self, emissions, masks=None, beam=None): + """ + Find the most likely output sequence using Viterbi algorithm. + + Args: + emissions (`~torch.Tensor`): Emission score are usually the unnormalized decoder output + ``(batch_size, seq_len, vocab_size)``. We assume batch-first + masks (`~torch.ByteTensor`): Mask tensor with the same size as targets + + Returns: + `~torch.LongTensor`: decoded sequence from the CRF model + """ + return self._viterbi_decode(emissions, masks, beam) + + def _compute_score(self, emissions, targets, masks=None): + batch_size, seq_len = targets.size() + emission_scores = emissions.gather(2, targets[:, :, None])[:, :, 0] # B x T + transition_scores = (self.E1(targets[:, :-1]) * self.E2(targets[:, 1:])).sum(2) + + scores = emission_scores + scores[:, 1:] += transition_scores + + if masks is not None: + scores = scores * masks.type_as(scores) + return scores.sum(-1) + + def _compute_normalizer(self, emissions, targets=None, masks=None, beam=None): + # HACK: we include "target" which is a hueristic for training + # HACK: we use a beam of tokens to approximate the normalizing factor (which is bad?) + + beam = beam if beam is not None else self.beam + batch_size, seq_len = emissions.size()[:2] + if targets is not None: + _emissions = emissions.scatter(2, targets[:, :, None], np.float("inf")) + beam_targets = _emissions.topk(beam, 2)[1] + beam_emission_scores = emissions.gather(2, beam_targets) + else: + beam_emission_scores, beam_targets = emissions.topk(beam, 2) + beam_transition_score1 = self.E1(beam_targets[:, :-1]) # B x (T-1) x K x D + beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D + beam_transition_matrix = torch.bmm( + beam_transition_score1.view(-1, beam, self.rank), + beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2), + ) + beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) + + # compute the normalizer in the log-space + score = beam_emission_scores[:, 0] # B x K + for i in range(1, seq_len): + next_score = score[:, :, None] + beam_transition_matrix[:, i - 1] + next_score = logsumexp(next_score, dim=1) + beam_emission_scores[:, i] + + if masks is not None: + score = torch.where(masks[:, i : i + 1], next_score, score) + else: + score = next_score + + # Sum (log-sum-exp) over all possible tags + return logsumexp(score, dim=1) + + def _viterbi_decode(self, emissions, masks=None, beam=None): + # HACK: we use a beam of tokens to approximate the normalizing factor (which is bad?) + + beam = beam if beam is not None else self.beam + batch_size, seq_len = emissions.size()[:2] + beam_emission_scores, beam_targets = emissions.topk(beam, 2) + beam_transition_score1 = self.E1(beam_targets[:, :-1]) # B x (T-1) x K x D + beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D + beam_transition_matrix = torch.bmm( + beam_transition_score1.view(-1, beam, self.rank), + beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2), + ) + beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) + + traj_tokens, traj_scores = [], [] + finalized_tokens, finalized_scores = [], [] + + # compute the normalizer in the log-space + score = beam_emission_scores[:, 0] # B x K + dummy = ( + torch.arange(beam, device=score.device).expand(*score.size()).contiguous() + ) + + for i in range(1, seq_len): + traj_scores.append(score) + _score = score[:, :, None] + beam_transition_matrix[:, i - 1] + _score, _index = _score.max(dim=1) + _score = _score + beam_emission_scores[:, i] + + if masks is not None: + score = torch.where(masks[:, i : i + 1], _score, score) + index = torch.where(masks[:, i : i + 1], _index, dummy) + else: + score, index = _score, _index + traj_tokens.append(index) + + # now running the back-tracing and find the best + best_score, best_index = score.max(dim=1) + finalized_tokens.append(best_index[:, None]) + finalized_scores.append(best_score[:, None]) + + for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)): + previous_index = finalized_tokens[-1] + finalized_tokens.append(idx.gather(1, previous_index)) + finalized_scores.append(scs.gather(1, previous_index)) + + finalized_tokens.reverse() + finalized_tokens = torch.cat(finalized_tokens, 1) + finalized_tokens = beam_targets.gather(2, finalized_tokens[:, :, None])[:, :, 0] + + finalized_scores.reverse() + finalized_scores = torch.cat(finalized_scores, 1) + finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1] + + return finalized_scores, finalized_tokens diff --git a/fairseq/modules/dynamicconv_layer/__init__.py b/fairseq/modules/dynamicconv_layer/__init__.py new file mode 100644 index 0000000..22dc6f4 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .dynamicconv_layer import DynamicconvLayer # noqa diff --git a/fairseq/modules/dynamicconv_layer/cuda_function_gen.py b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py new file mode 100644 index 0000000..9304f99 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py @@ -0,0 +1,223 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +def gen_forward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + blocks = [32, 64, 128, 256] + + head = """ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dynamicconv_cuda.cuh" + +std::vector dynamicconv_cuda_forward(at::Tensor input, at::Tensor weight, int padding_l) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = weight.size(1); + const auto filterSize = weight.size(2); + + const auto numFiltersInBlock = numFeatures / numHeads; + const dim3 blocks(minibatch, numFeatures); + + auto output = at::zeros_like(input); + auto stream = at::cuda::getCurrentCUDAStream(); +""" + + switch = """ + switch(filterSize) { +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {pad}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "dynamicconv_forward", ([&] {{ + dynamicconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t> + <<>>( + input.data(), + weight.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + output.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl; + } + break;\n +""" + + end = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl; + } + + return {output}; +} +""" + + with open("dynamicconv_cuda_forward.cu", "w") as forward: + forward.write(head) + forward.write(switch) + for k in kernels: + b_size = 32 + for b in blocks: + if b > k: + b_size = b + break + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=b_size, pad=pad)) + forward.write(bad_padding) + forward.write(end) + + +def gen_backward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + thresh = [512, 512, 512, 512, 512, 380, 256, 256] + min_block = [64, 64, 64, 64, 64, 64, 128, 256] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + + head = """ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dynamicconv_cuda.cuh" + +std::vector dynamicconv_cuda_backward(at::Tensor gradOutput, int padding_l, at::Tensor input, at::Tensor weight) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = weight.size(1); + const auto filterSize = weight.size(2); + + const auto numFiltersInBlock = numFeatures / numHeads; + auto numChunks = 1; + + auto gradInput = at::zeros_like(input); + auto gradWeight = at::zeros_like(weight); + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(minibatch, numHeads, numChunks); +""" + + sequence_if = """ + if (sequenceLength < {seq}) {{ + switch(filterSize) {{ +""" + + case_k = """ + case {k}: +""" + + chunks_reset = """ + numChunks = int(ceilf(sequenceLength/float({b_size}))); + blocks = dim3(minibatch, numHeads, numChunks); +""" + + main_block = """ + if (padding_l == {p}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.scalar_type(), "dynamicconv_backward", ([&] {{ + dynamicconv_backward_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + gradOutput.data(), + input.data(), + weight.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + gradWeight.data(), + gradInput.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl; + } + break;\n +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl; + } +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + last_return = """ + } + return {gradInput, gradWeight}; +} +""" + + with open("dynamicconv_cuda_backward.cu", "w") as backward: + backward.write(head) + for seq in seqs: + backward.write(sequence_if.format(seq=seq)) + for k, t, m in zip(kernels, thresh, min_block): + backward.write(case_k.format(k=k)) + if seq <= t: + b_size = seq + else: + b_size = m + backward.write(chunks_reset.format(b_size=b_size)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=b_size, p=p)) + backward.write(bad_padding) + backward.write(bad_filter) + backward.write(con_else) + backward.write(final_else) + for k, m in zip(kernels, min_block): + backward.write(case_k.format(k=k)) + backward.write(chunks_reset.format(b_size=m)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=m, p=p)) + backward.write(bad_padding) + backward.write(bad_filter) + backward.write(last_return) + + +if __name__ == "__main__": + gen_forward() + gen_backward() diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp new file mode 100644 index 0000000..ebd4df0 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp @@ -0,0 +1,56 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +std::vector dynamicconv_cuda_forward( + at::Tensor input, + at::Tensor filters, + int padding_l); + +std::vector dynamicconv_cuda_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters); + + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector dynamicconv_forward( + at::Tensor input, + at::Tensor filters, + int padding_l) { + + CHECK_INPUT(input); + CHECK_INPUT(filters); + + return dynamicconv_cuda_forward(input, filters, + padding_l); +} + +std::vector dynamicconv_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters) { + + CHECK_INPUT(gradOutput); + CHECK_INPUT(input); + CHECK_INPUT(filters); + + return dynamicconv_cuda_backward(gradOutput, padding_l, + input, filters); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); + m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); +} diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh new file mode 100644 index 0000000..2196259 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh @@ -0,0 +1,51 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define SHFL_MASK 0xffffffff + +template +__global__ +void dynamicconv_forward_kernel(const scalar_t* input, + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* output); + +template +__global__ +void dynamicconv_backward_kernel( + const scalar_t* gradOutput, // B * C * T + const scalar_t* input, // B * C * T + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* gradWeight, + scalar_t* gradInput); // B * H * k * T diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu new file mode 100644 index 0000000..300d35b --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu @@ -0,0 +1,168 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dynamicconv_cuda.cuh" +#include "dynamicconv_cuda_forward.cu" +#include "dynamicconv_cuda_backward.cu" +#include "../cuda_utils.cu" + +// FS is filter size and kernels are specialized for filter sizes +template +__global__ +void dynamicconv_forward_kernel(const scalar_t* input, + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* output) { + assert(blockDim.x == SB); + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int head = featureIdx / numFiltersInBlock; + + const int IOOffset = batchIdx * numFeatures * sequenceLength + + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + scalar_t* outputFeature = &output[IOOffset]; + + scalar_t filter[FS]; + + __shared__ scalar_t tempInput[SB + FS]; + zeroSharedMem(tempInput); + + const int numIterations = divUp(sequenceLength, SB); + + for (int i = 0; i < numIterations; ++i) { + __syncthreads(); + const int inputOffset = i * SB; + load_input_to_shared(inputFeature, inputOffset, + sequenceLength, i, + numIterations, false, tempInput); + __syncthreads(); + if (inputOffset + tid < sequenceLength) { + + #pragma unroll + for (int k = 0; k < FS; ++k) { + const int filterOffset = batchIdx * numHeads * FS * sequenceLength + + head * FS * sequenceLength + + k * sequenceLength + + i * SB + tid; + filter[k] = weight[filterOffset]; + } + + scalar_t out = scalar_t(0.0); + #pragma unroll + for (int k = 0; k < FS; ++k) { + out += filter[k] * tempInput[tid + k]; + } + + outputFeature[inputOffset + tid] = out; + + } + } +} + +template +__global__ +void dynamicconv_backward_kernel( + const scalar_t* gradOutput, // B * C * T + const scalar_t* input, // B * C * T + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* gradWeight, + scalar_t* gradInput) { // B * H * k * T + + assert(blockDim.x == SB); + + // each block operates on a single batch and filter head + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int headIdx = blockIdx.y; + const int chunkIdx = blockIdx.z; + + const int numChunks = divUp(sequenceLength, SB); + const int inputOffset = chunkIdx * SB; + + // initialize shared memory for output gradient and input + __shared__ scalar_t tempGradOutput[SB + FS]; + __shared__ scalar_t tempInput[SB + FS]; + const int padding = FS - padding_l - 1; + + zeroSharedMem(tempGradOutput); + zeroSharedMem(tempInput); + + // initialize local filter and weight gradient sum arrays + scalar_t tempGradSum[FS]; + scalar_t bfilter[FS]; + for (int k = 0; k < FS; ++k) { + tempGradSum[k] = scalar_t(0.0); + + int idxOffset = inputOffset + tid + k - padding; + if (idxOffset >= 0 && idxOffset < sequenceLength) { + int bfilterOffset = batchIdx * numHeads * FS * sequenceLength + + headIdx * FS * sequenceLength + + (FS - k - 1) * sequenceLength + + idxOffset; + bfilter[k] = weight[bfilterOffset]; + } else { + bfilter[k] = scalar_t(0.0); + } + } + + + // iterate over filter block + for (int featureIdx = 0; featureIdx < numFiltersInBlock; ++featureIdx) { + __syncthreads(); + + // load input and output gradient for this channel and chunk + const int IOOffset = batchIdx * numFeatures * sequenceLength + + (headIdx * numFiltersInBlock + featureIdx) * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + const scalar_t* gradOutputFeature = &gradOutput[IOOffset]; + scalar_t* gradInputFeature = &gradInput[IOOffset]; + + load_input_to_shared(gradOutputFeature, inputOffset, + sequenceLength, chunkIdx, + numChunks, true, tempGradOutput); + load_input_to_shared(inputFeature, inputOffset, + sequenceLength, chunkIdx, + numChunks, true, tempInput); + __syncthreads(); + + // sum input and weight gradients + scalar_t out = scalar_t(0.0); + #pragma unroll + for (int k = 0; k < FS; ++k) { + tempGradSum[k] += tempInput[tid + k] * tempGradOutput[tid + padding]; + out += bfilter[k] * tempGradOutput[tid + k]; + } + + if (inputOffset + tid < sequenceLength) { + gradInputFeature[inputOffset + tid] = out; + } + } + + const int gradOffset = batchIdx * numHeads * FS * sequenceLength + + headIdx * FS * sequenceLength; + scalar_t *gradWeightFeature = &gradWeight[gradOffset]; + + // write weight gradient + if (inputOffset + tid < sequenceLength) { + for (int k = 0; k < FS; ++k) { + const int outputOffset = k * sequenceLength + inputOffset + tid; + gradWeightFeature[outputOffset] = tempGradSum[k]; + } + } +} diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py new file mode 100644 index 0000000..4a683d2 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py @@ -0,0 +1,227 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import dynamicconv_cuda +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.unfold import unfold1d +from torch import nn +from torch.autograd import Function + + +class dynamicconvFunction(Function): + @staticmethod + def forward(ctx, x, weights, padding_l): + ctx.padding_l = padding_l + outputs = dynamicconv_cuda.forward(x, weights, padding_l) + variables = [x, weights] + ctx.save_for_backward(*variables) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output): + outputs = dynamicconv_cuda.backward( + grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors + ) + grad_input, grad_weights = outputs + return grad_input, grad_weights, None + + +@with_incremental_state +class DynamicconvLayer(nn.Module): + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + weight_softmax=False, + num_heads=1, + weight_dropout=0.0, + bias=False, + renorm_padding=False, + conv_bias=False, + query_size=None, + ): + + super(DynamicconvLayer, self).__init__() + self.input_size = input_size + self.query_size = input_size if query_size is None else query_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_softmax = weight_softmax + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.renorm_padding = renorm_padding + self.bias = bias + + self.weight_linear = nn.Linear(input_size, num_heads * kernel_size, bias) + if conv_bias: + self.conv_bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.conv_bias = None + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight_linear.weight) + if self.conv_bias is not None: + nn.init.constant_(self.conv_bias, 0.0) + nn.init.constant_(self.weight_linaer.bias, 0.0) + + def forward(self, x, incremental_state=None, query=None, unfold=None): + + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + # R = C // H + + # during inference time, incremental BMM is faster + if incremental_state is not None: + unfold = ( + x.size(0) > 512 if unfold is None else unfold + ) # use unfold mode as default for long sequence to save memory + unfold = unfold or (incremental_state is not None) + assert query is None + + if query is None: + query = x + if unfold: + output = self._forward_unfolded(x, incremental_state, query) + else: + output = self._forward_expanded(x, incremental_state, query) + + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + + return output + + # during training time, use CUDA kernel + else: + weight = self.weight_linear(x).view(T, B, H, K) + if self.weight_softmax: + weight = F.softmax(weight, dim=-1) + if self.weight_dropout_module.p: + weight = self.weight_dropout_module(weight) + + weight = weight.permute(1, 2, 3, 0).contiguous() + self.filters = weight + x = x.permute(1, 2, 0).contiguous() + output = dynamicconvFunction.apply(x, weight, self.padding_l).permute( + 2, 0, 1 + ) + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def _forward_unfolded(self, x, incremental_state, query): + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + weight = self.weight_linear(query).view(T * B * H, -1) + + # renorm_padding is only implemented in _forward_expanded + assert not self.renorm_padding or incremental_state is not None + + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) + else: + padding_l = self.padding_l + if K > T and padding_l == K - 1: + weight = weight.narrow(1, K - T, T) + K, padding_l = T, T - 1 + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, K, padding_l, 0) + x_unfold = x_unfold.view(T * B * H, R, K) + + if self.weight_softmax and not self.renorm_padding: + weight = F.softmax(weight, dim=1) + weight = weight.narrow(1, 0, K) + + if incremental_state is not None: + weight = weight[:, -x_unfold.size(2) :] + K = weight.size(1) + + if self.weight_softmax and self.renorm_padding: + weight = F.softmax(weight, dim=1) + + weight = self.weight_dropout_module(weight, inplace=False) + + output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + def _forward_expanded(self, x, incremental_stat, query): + """Turn the convolution filters into band matrices and do matrix multiplication. + This is faster when the sequence is short, but less memory efficient. + This is not used in the decoder during inference. + """ + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + weight = self.weight_linear(query).view(T * B * H, -1) + + if not self.renorm_padding: + if self.weight_softmax: + weight = F.softmax(weight, dim=1) + weight = self.weight_dropout_module(weight, inplace=False) + weight = weight.narrow(1, 0, K).contiguous() + weight = weight.view(T, B * H, K).transpose(0, 1) + + x = x.view(T, B * H, R).transpose(0, 1) + if self.weight_softmax and self.renorm_padding: + # turn the convolution filters into band matrices + weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf")) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) + weight_expanded = weight_expanded.narrow(2, self.padding_l, T) + # normalize the weight over valid positions like self-attention + weight_expanded = F.softmax(weight_expanded, dim=2) + weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False) + else: + P = self.padding_l + # For efficieny, we cut the kernel size and reduce the padding when the kernel is larger than the length + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) + weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output diff --git a/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp b/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp new file mode 100644 index 0000000..8a6af42 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp @@ -0,0 +1,35 @@ +#include +#include + +std::vector dynamicconv_cpu_forward( + float* input, + float* filters, + int padding_l); + +std::vector dynamicconv_cpu_backward( + float* gradOutput, + int padding_l, + float* input, + float* filters); + +std::vector dynamicconv_forward( + float* input, + float* filters, + int padding_l) { + + return dynamicconv_cpu_forward(input, filters, padding_l); +} + +std::vector dynamicconv_backward( + float* gradOutput, + int padding_l, + float* input, + float* filters) { + + return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); + m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); +} diff --git a/fairseq/modules/dynamicconv_layer/setup.py b/fairseq/modules/dynamicconv_layer/setup.py new file mode 100644 index 0000000..6a21f7e --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/setup.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name="dynamicconv_layer", + ext_modules=[ + CUDAExtension( + name="dynamicconv_cuda", + sources=[ + "dynamicconv_cuda.cpp", + "dynamicconv_cuda_kernel.cu", + ], + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/fairseq/modules/fairseq_dropout.py b/fairseq/modules/fairseq_dropout.py new file mode 100644 index 0000000..f070a80 --- /dev/null +++ b/fairseq/modules/fairseq_dropout.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Optional + +import torch.nn as nn +import torch.nn.functional as F + + +logger = logging.getLogger(__name__) + + +class FairseqDropout(nn.Module): + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x, inplace: bool = False): + if self.training or self.apply_during_inference: + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + + def make_generation_fast_( + self, + name: str, + retain_dropout: bool = False, + retain_dropout_modules: Optional[List[str]] = None, + **kwargs + ): + if retain_dropout: + if retain_dropout_modules is not None and self.module_name is None: + logger.warning( + "Cannot enable dropout during inference for module {} " + "because module_name was not set".format(name) + ) + elif ( + retain_dropout_modules is None # if None, apply to all modules + or self.module_name in retain_dropout_modules + ): + logger.info( + "Enabling dropout during inference for module: {}".format(name) + ) + self.apply_during_inference = True + else: + logger.info("Disabling dropout for module: {}".format(name)) diff --git a/fairseq/modules/fp32_group_norm.py b/fairseq/modules/fp32_group_norm.py new file mode 100644 index 0000000..d03aac0 --- /dev/null +++ b/fairseq/modules/fp32_group_norm.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Layer norm done in fp32 (for fp16 training) +""" + +import torch.nn as nn +import torch.nn.functional as F + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) diff --git a/fairseq/modules/gelu.py b/fairseq/modules/gelu.py new file mode 100644 index 0000000..a2f1ecf --- /dev/null +++ b/fairseq/modules/gelu.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with +the corresponding GitHub repo: https://github.com/hendrycks/GELUs +""" + +import math + +import torch +import torch.nn as nn + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) diff --git a/fairseq/modules/grad_multiply.py b/fairseq/modules/grad_multiply.py new file mode 100644 index 0000000..08d15f5 --- /dev/null +++ b/fairseq/modules/grad_multiply.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py new file mode 100644 index 0000000..47657bb --- /dev/null +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -0,0 +1,199 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GumbelVectorQuantizer(nn.Module): + def __init__( + self, + dim, + num_vars, + temp, + groups, + combine_groups, + vq_dim, + time_first, + activation=nn.GELU(), + weight_proj_depth=1, + weight_proj_factor=1, + ): + """Vector quantization using gumbel softmax + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor) + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1 + weight_proj_depth: number of layers (with activation in between) to project input before computing logits + weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of + projections by this factor + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.time_first = time_first + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) + nn.init.uniform_(self.vars) + + if weight_proj_depth > 1: + + def block(input_dim, output_dim): + return nn.Sequential(nn.Linear(input_dim, output_dim), activation) + + inner_dim = self.input_dim * weight_proj_factor + self.weight_proj = nn.Sequential( + *[ + block(self.input_dim if i == 0 else inner_dim, inner_dim) + for i in range(weight_proj_depth - 1) + ], + nn.Linear(inner_dim, groups * num_vars), + ) + else: + self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.init.normal_(self.weight_proj.weight, mean=0, std=1) + nn.init.zeros_(self.weight_proj.bias) + + assert len(temp) == 3, temp + + self.max_temp, self.min_temp, self.temp_decay = temp + self.curr_temp = self.max_temp + self.codebook_indices = None + + def set_num_updates(self, num_updates): + self.curr_temp = max( + self.max_temp * self.temp_decay ** num_updates, self.min_temp + ) + + def get_codebook_indices(self): + if self.codebook_indices is None: + from itertools import product + + p = [range(self.num_vars)] * self.groups + inds = list(product(*p)) + self.codebook_indices = torch.tensor( + inds, dtype=torch.long, device=self.vars.device + ).flatten() + + if not self.combine_groups: + self.codebook_indices = self.codebook_indices.view( + self.num_vars ** self.groups, -1 + ) + for b in range(1, self.groups): + self.codebook_indices[:, b] += self.num_vars * b + self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + + def codebook(self): + indices = self.get_codebook_indices() + return ( + self.vars.squeeze(0) + .index_select(0, indices) + .view(self.num_vars ** self.groups, -1) + ) + + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.view(-1, self.groups) + cb_size = indices.size(0) + assert ( + n < cb_size + ), f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1) + return z + + def to_codebook_index(self, indices): + res = indices.new_full(indices.shape[:-1], 0) + for i in range(self.groups): + exponent = self.groups - i - 1 + res += indices[..., i] * (self.num_vars ** exponent) + return res + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars * self.groups} + + if not self.time_first: + x = x.transpose(1, 2) + + bsz, tsz, fsz = x.shape + x = x.reshape(-1, fsz) + x = self.weight_proj(x) + x = x.view(bsz * tsz * self.groups, -1) + + _, k = x.max(-1) + hard_x = ( + x.new_zeros(*x.shape) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + avg_probs = torch.softmax( + x.view(bsz * tsz, self.groups, -1).float(), dim=-1 + ).mean(dim=0) + result["prob_perplexity"] = torch.exp( + -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) + ).sum() + + result["temp"] = self.curr_temp + + if self.training: + x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x) + else: + x = hard_x + + x = x.view(bsz * tsz, -1) + + vars = self.vars + if self.combine_groups: + vars = vars.repeat(1, self.groups, 1) + + if produce_targets: + result["targets"] = ( + x.view(bsz * tsz * self.groups, -1) + .argmax(dim=-1) + .view(bsz, tsz, self.groups) + .detach() + ) + + x = x.unsqueeze(-1) * vars + x = x.view(bsz * tsz, self.groups, self.num_vars, -1) + x = x.sum(-2) + x = x.view(bsz, tsz, -1) + + if not self.time_first: + x = x.transpose(1, 2) # BTC -> BCT + + result["x"] = x + + return result diff --git a/fairseq/modules/kmeans_vector_quantizer.py b/fairseq/modules/kmeans_vector_quantizer.py new file mode 100644 index 0000000..040db1e --- /dev/null +++ b/fairseq/modules/kmeans_vector_quantizer.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from fairseq.modules import Fp32GroupNorm + + +class KmeansVectorQuantizer(nn.Module): + def __init__( + self, dim, num_vars, groups, combine_groups, vq_dim, time_first, gamma=0.25 + ): + """Vector quantization using straight pass-through estimator (i.e. kmeans) + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + gamma: commitment loss coefficient + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.vq_dim = vq_dim + self.time_first = time_first + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + self.var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.embedding = nn.Parameter( + 0.01 * torch.randn(num_vars, num_groups, self.var_dim) + ) + self.projection = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size=1, groups=groups, bias=False), + Fp32GroupNorm(groups, dim), + ) + self.gamma = gamma + self.mse_mean = nn.MSELoss(reduction="mean") + + def _pass_grad(self, x, y): + """Manually set gradient for backward pass. + for y = f(x), ensure that during the backward pass, + dL/dy = dL/dx regardless of f(x). + Returns: + y, with the gradient forced to be dL/dy = dL/dx. + """ + + return y.detach() + (x - x.detach()) + + @property + def expand_embedding(self): + if self.combine_groups: + return self.embedding.expand(self.num_vars, self.groups, self.var_dim) + return self.embedding + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars} + + if self.time_first: + x = x.transpose(1, 2) + + bsz, fsz, tsz = x.shape + + ze = self.projection(x) + ze_ = ze.view(bsz, self.groups, self.var_dim, tsz).permute(0, 3, 1, 2) + d = ( + (ze_.unsqueeze(0) - self.expand_embedding.unsqueeze(1).unsqueeze(1)) + .view(self.num_vars, bsz, tsz, self.groups, -1) + .norm(dim=-1, p=2) + ) + idx = d.argmin(dim=0) + zq = ( + torch.stack( + [ + self.expand_embedding[idx[..., group], group] + for group in range(self.groups) + ], + dim=-2, + ) + .view(bsz, tsz, self.groups * self.var_dim) + .permute(0, 2, 1) + ) + assert ze.shape == zq.shape, (ze.shape, zq.shape) + x = self._pass_grad(ze, zq) + + hard_x = ( + idx.new_zeros(bsz * tsz * self.groups, self.num_vars) + .scatter_(-1, idx.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + if produce_targets: + result["targets"] = idx + + if self.time_first: + x = x.transpose(1, 2) # BCT -> BTC + result["x"] = x + + ze = ze.float() + zq = zq.float() + latent_loss = self.mse_mean(zq, ze.detach()) + commitment_loss = self.mse_mean(ze, zq.detach()) + + result["kmeans_loss"] = latent_loss + self.gamma * commitment_loss + + return result diff --git a/fairseq/modules/layer_drop.py b/fairseq/modules/layer_drop.py new file mode 100644 index 0000000..8961d8b --- /dev/null +++ b/fairseq/modules/layer_drop.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +LayerDrop as described in https://arxiv.org/abs/1909.11556. +""" + +import torch +import torch.nn as nn + + +class LayerDropModuleList(nn.ModuleList): + """ + A LayerDrop implementation based on :class:`torch.nn.ModuleList`. + + We refresh the choice of which layers to drop every time we iterate + over the LayerDropModuleList instance. During evaluation we always + iterate over all layers. + + Usage:: + + layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) + for layer in layers: # this might iterate over layers 1 and 3 + x = layer(x) + for layer in layers: # this might iterate over all layers + x = layer(x) + for layer in layers: # this might not iterate over any layers + x = layer(x) + + Args: + p (float): probability of dropping out each layer + modules (iterable, optional): an iterable of modules to add + """ + + def __init__(self, p, modules=None): + super().__init__(modules) + self.p = p + + def __iter__(self): + dropout_probs = torch.empty(len(self)).uniform_() + for i, m in enumerate(super().__iter__()): + if not self.training or (dropout_probs[i] > self.p): + yield m diff --git a/fairseq/modules/layer_norm.py b/fairseq/modules/layer_norm.py new file mode 100644 index 0000000..234609d --- /dev/null +++ b/fairseq/modules/layer_norm.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + + +except ImportError: + has_fused_layernorm = False + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if torch.jit.is_scripting(): + export = True + if not export and torch.cuda.is_available() and has_fused_layernorm: + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) diff --git a/fairseq/modules/learned_positional_embedding.py b/fairseq/modules/learned_positional_embedding.py new file mode 100644 index 0000000..378d0f7 --- /dev/null +++ b/fairseq/modules/learned_positional_embedding.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from torch import Tensor + + +class LearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + Padding ids are ignored by either offsetting based on padding_idx + or by setting padding_idx to None and ensuring that the appropriate + position ids are passed to the forward function. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.onnx_trace = False + if self.padding_idx is not None: + self.max_positions = self.num_embeddings - self.padding_idx - 1 + else: + self.max_positions = self.num_embeddings + + def forward( + self, + input: Tensor, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + positions: Optional[Tensor] = None, + ): + """Input is expected to be of size [bsz x seqlen].""" + assert (positions is None) or ( + self.padding_idx is None + ), "If positions is pre-computed then padding_idx should not be set." + + if positions is None: + if incremental_state is not None: + # positions is the same for every token when decoding a single step + # Without the int() cast, it doesn't work in some cases when exporting to ONNX + positions = torch.zeros( + (1, 1), device=input.device, dtype=input.dtype + ).fill_(int(self.padding_idx + input.size(1))) + else: + positions = utils.make_positions( + input, self.padding_idx, onnx_trace=self.onnx_trace + ) + return F.embedding( + positions, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) diff --git a/fairseq/modules/lightconv_layer/__init__.py b/fairseq/modules/lightconv_layer/__init__.py new file mode 100644 index 0000000..3b2a99c --- /dev/null +++ b/fairseq/modules/lightconv_layer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .lightconv_layer import LightconvLayer # noqa diff --git a/fairseq/modules/lightconv_layer/cuda_function_gen.py b/fairseq/modules/lightconv_layer/cuda_function_gen.py new file mode 100644 index 0000000..a25433d --- /dev/null +++ b/fairseq/modules/lightconv_layer/cuda_function_gen.py @@ -0,0 +1,289 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +def gen_forward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + + head = """ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "lightconv_cuda.cuh" + +std::vector lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = filters.size(0); + const auto filterSize = filters.size(1); + + const auto numFiltersInBlock = numFeatures / numHeads; + + const dim3 blocks(minibatch, numFeatures); + + auto output = at::zeros_like(input); + auto stream = at::cuda::getCurrentCUDAStream(); +""" + + sequence_if = """ + if (sequenceLength <= {seq}) {{ + switch(filterSize) {{ +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {pad}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_forward", ([&] {{ + lightconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t> + <<>>( + input.data(), + filters.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + output.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl; + } + break; +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl; + } +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + final_return = """ + } + + return {output}; +} +""" + + with open("lightconv_cuda_forward.cu", "w") as forward: + forward.write(head) + for seq in seqs: + forward.write(sequence_if.format(seq=seq)) + for k in kernels: + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=seq, pad=pad)) + forward.write(bad_padding) + forward.write(bad_filter) + forward.write(con_else) + + forward.write(final_else) + for k in kernels: + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=seq, pad=pad)) + forward.write(bad_padding) + forward.write(bad_filter) + forward.write(final_return) + + +def gen_backward(): + + head = """ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "lightconv_cuda.cuh" + +std::vector lightconv_cuda_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters) { + + // gradWrtInput + const int minibatch = input.size(0); + const int numFeatures = input.size(1); + const int sequenceLength = input.size(2); + + const int numHeads = filters.size(0); + const int filterSize = filters.size(1); + + const dim3 gradBlocks(minibatch, numFeatures); + const dim3 weightGradFirstpassShortBlocks(minibatch, numHeads); + const dim3 weightGradSecondpassBlocks(numHeads, filterSize); + + const int numFiltersInBlock = numFeatures / numHeads; + + auto gradInput = at::zeros_like(input); + auto gradFilters = at::zeros_like(filters); + + at::DeviceGuard g(input.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + + switch(filterSize) { +""" + + sequence_if = """ + if (sequenceLength <= {seq}) {{ +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {p}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_backward", ([&] {{ + lightconv_grad_wrt_input_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + gradOutput.data(), + filters.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + gradInput.data()); + +""" + + weight_grad_short = """ + at::Tensor tempSumGradFilters = at::zeros({{minibatch, numHeads, filterSize}}, input.options().dtype(at::kFloat)); + lightconv_grad_wrt_weights_firstpass_short_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + input.data(), + gradOutput.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + tempSumGradFilters.data() + ); + + lightconv_grad_wrt_weights_secondpass_short_kernel<{k}, {b_size}, scalar_t> + <<>>( + tempSumGradFilters.data(), + minibatch, + numFiltersInBlock, + gradFilters.data() + ); + }})); + }} else +""" + + weight_grad = """ + at::Tensor tempSumGradFilters = at::zeros({{minibatch, numFeatures, filterSize}}, input.options().dtype(at::kFloat)); + lightconv_grad_wrt_weights_firstpass_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + input.data(), + gradOutput.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + tempSumGradFilters.data() + ); + + lightconv_grad_wrt_weights_secondpass_kernel<{k}, {b_size}, scalar_t> + <<>>( + tempSumGradFilters.data(), + minibatch, + numFiltersInBlock, + gradFilters.data() + ); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl; + } +""" + + breakout = """ + break; +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl; +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + last_return = """ + } + return {gradInput, gradFilters}; +} +""" + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + thresh = [32, 32, 64, 128, 256, -1, -1, -1] + max_mem = [-1, -1, -1, -1, -1, 192, 96, 64] + + with open("lightconv_cuda_backward.cu", "w") as backward: + backward.write(head) + for (k, t, mem) in zip(kernels, thresh, max_mem): + backward.write(case_k.format(k=k)) + for seq in seqs: + if (t == -1 or seq <= t) and (mem == -1 or seq < mem): + backward.write(sequence_if.format(seq=seq)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=seq, p=p)) + backward.write(weight_grad_short.format(k=k, b_size=seq, p=p)) + backward.write(bad_padding) + else: + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=32, p=p)) + backward.write(weight_grad.format(k=k, b_size=32, p=p)) + backward.write(bad_padding) + backward.write(breakout) + break + backward.write(con_else) + backward.write(bad_filter) + backward.write(last_return) + + +if __name__ == "__main__": + gen_forward() + gen_backward() diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda.cpp b/fairseq/modules/lightconv_layer/lightconv_cuda.cpp new file mode 100644 index 0000000..4bf6b5a --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_cuda.cpp @@ -0,0 +1,54 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +std::vector lightconv_cuda_forward( + at::Tensor input, + at::Tensor filters, + int padding_l); + +std::vector lightconv_cuda_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters); + + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector lightconv_forward( + at::Tensor input, + at::Tensor filters, + int padding_l) { + + CHECK_INPUT(input); + CHECK_INPUT(filters); + + return lightconv_cuda_forward(input, filters, padding_l); +} + +std::vector lightconv_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters) { + + CHECK_INPUT(gradOutput); + CHECK_INPUT(input); + CHECK_INPUT(filters); + + return lightconv_cuda_backward(gradOutput, padding_l, input, filters); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); + m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); +} diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda.cuh b/fairseq/modules/lightconv_layer/lightconv_cuda.cuh new file mode 100644 index 0000000..3cae57b --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_cuda.cuh @@ -0,0 +1,83 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#define SHFL_MASK 0xffffffff + +template +__global__ +void lightconv_forward_kernel(const scalar_t* input, + const scalar_t* filters, + int minibatch, int sequenceLength, + int numFeatures, int numFiltersInBlock, + scalar_t* output); + +template +__global__ +void lightconv_grad_wrt_input_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output); + +template +__global__ +void lightconv_grad_wrt_weights_firstpass_short_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + float* output); + +template +__global__ +void lightconv_grad_wrt_weights_secondpass_short_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output); + +template +__global__ +void lightconv_grad_wrt_weights_firstpass_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + float* output); + +template +__global__ +void lightconv_grad_wrt_weights_secondpass_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output); + diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu b/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu new file mode 100644 index 0000000..8ee83a5 --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu @@ -0,0 +1,375 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "lightconv_cuda.cuh" +#include "lightconv_cuda_forward.cu" +#include "lightconv_cuda_backward.cu" +#include "../cuda_utils.cu" + +template +__global__ +void lightconv_forward_kernel(const scalar_t* input, + const scalar_t* filters, + int minibatch, int sequenceLength, + int numFeatures, int numFiltersInBlock, + scalar_t* output) { + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + + const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + scalar_t* outputFeature = &output[IOOffset]; + const scalar_t* inputFilter = &filters[filterIdx * FS]; + + assert(blockDim.x == SB); + + scalar_t filter[FS]; + #pragma unroll + for (int i = 0; i < FS; ++i) { + filter[i] = inputFilter[i]; + } + + __shared__ scalar_t temp[SB + FS]; + zeroSharedMem(temp); + + const int numIterations = divUp(sequenceLength, SB); + + for (int i = 0; i < numIterations; ++i) { + // Read input into shared memory + const int inputOffset = i * SB; + + load_input_to_shared(inputFeature, inputOffset, sequenceLength, + i, numIterations, (numIterations == 1), temp); + + __syncthreads(); + + scalar_t out = 0; + #pragma unroll + for (int j = 0; j < FS; ++j) { + out += filter[j] * temp[tid + j]; + } + + // Write output + const int outputOffset = inputOffset; + if ((outputOffset + tid) < sequenceLength) { + outputFeature[outputOffset + tid] = out; + } + + __syncthreads(); + } +} + +template +__global__ +void lightconv_grad_wrt_input_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output) { + + // input grad kernel is similar to forward kernel + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + + const int IOOffset = numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + scalar_t* outputFeature = &output[IOOffset]; + const scalar_t* inputFilter = &filters[filterIdx * FS]; + + assert(blockDim.x == SB); + + scalar_t filter[FS]; + + // The only change is loading the filter in reverse + #pragma unroll + for (int i = 0; i < FS; ++i) { + filter[i] = inputFilter[FS - i - 1]; + } + + __shared__ scalar_t temp[SB + FS]; + const int padding = FS - padding_l - 1; + zeroSharedMem(temp); + + __syncthreads(); + + const int numIterations = divUp(sequenceLength, SB); + + for (int i = 0; i < numIterations; ++i) { + // Read input into shared memory + const int inputOffset = i * SB; + + load_input_to_shared(inputFeature, inputOffset, sequenceLength, + i, numIterations, false, temp); + + __syncthreads(); + + scalar_t out = 0; + #pragma unroll + for (int j = 0; j < FS; ++j) { + out += filter[j] * temp[tid + j]; + } + + // Write output + const int outputOffset = inputOffset; + if ((outputOffset + tid) < sequenceLength) { + outputFeature[outputOffset + tid] = out; + } + + __syncthreads(); + } +} + +// This is by far the most expensive kernel in terms of time taken. +// Can be 16x slower than the forward or grad_wrt_input when filter size is 31 +template +__global__ +void lightconv_grad_wrt_weights_firstpass_short_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + float* output) { + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int filterIdx = blockIdx.y; + + const int numIterations = divUp(sequenceLength, SB); + + float* tempOutputGradWeight = &output[filterIdx * FS * minibatch]; + + assert(blockDim.x == SB); + + __shared__ scalar_t tempInput[SB + FS]; + __shared__ scalar_t tempGradInput[SB + FS]; + + // local weight accumulation + float accumWeights[FS]; + + // Initialize memory + for (int i = 0; i < FS; ++i) { + accumWeights[i] = float(0.0); + } + + + // loop over each sequence within filterblock + for (int idxInFilterBlock = 0; idxInFilterBlock < numFiltersInBlock; ++idxInFilterBlock) { + + const int featureOffset = batchIdx * numFeatures * sequenceLength + (filterIdx * numFiltersInBlock + idxInFilterBlock) * sequenceLength; + const scalar_t* inputFeature = &input[featureOffset]; + const scalar_t* gradInputFeature = &gradInput[featureOffset]; + + zeroSharedMem(tempInput); + zeroSharedMem(tempGradInput); + __syncthreads(); + + for (int i = 0; i < numIterations; ++i) { + + const int inputOffset = i * SB; + + load_input_to_shared(inputFeature, inputOffset, sequenceLength, + i, numIterations, false, tempInput); + load_input_to_shared(gradInputFeature, inputOffset, sequenceLength, + i, numIterations, false, tempGradInput); + + __syncthreads(); + + const int gradIndex = (FS/2) + tid; + scalar_t tempGrad = tempGradInput[gradIndex]; + + #pragma unroll + for (int j = 0; j < FS; j++) { + const int inputIndex = tid + j; + accumWeights[j] += tempInput[inputIndex] * tempGrad; + } + + __syncthreads(); + + } + + } + + // Row-major sum + for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) { + + float temp; + if (tid < sequenceLength) { + temp = accumWeights[filterWeightIdx]; + } else { + temp = float(0.0); + } + + const int outputOffset = filterWeightIdx * minibatch + batchIdx; + + temp = blockReduce(temp); + + if (tid == 0) { + tempOutputGradWeight[outputOffset] = temp; + } + } +} + +template +__global__ +void lightconv_grad_wrt_weights_secondpass_short_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output) { + + assert(blockDim.x == SB); + + const int tid = threadIdx.x; + + const int filterIdx = blockIdx.x; + const int filterWeightIdx = blockIdx.y; + + const int inputOffset = filterIdx * FS * minibatch + + filterWeightIdx * minibatch; + const float* tempInput = &input[inputOffset]; + + // read into shared memory for reduction + int readIndex = tid; + + float sum = 0.0; + while (readIndex < minibatch) { + sum += tempInput[readIndex]; + readIndex += SB; + } + + float temp = blockReduce(sum); + + if (tid == 0) { + output[blockIdx.x * FS + blockIdx.y] = temp; + } +} + +// This is by far the most expensive kernel in terms of time taken. +// Can be 16x slower than the forward or grad_wrt_input when filter size is 31 +template +__global__ +void lightconv_grad_wrt_weights_firstpass_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + float* output) { + + assert(blockDim.x == SB); + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + const int idxInFilterBlock = featureIdx % numFiltersInBlock; + + const int numIterations = divUp(sequenceLength, SB); + + float temp; + + __shared__ scalar_t tempInput[SB + FS]; + __shared__ scalar_t tempGradInput[SB + FS]; + zeroSharedMem(tempInput); + zeroSharedMem(tempGradInput); + __syncthreads(); + + float accumWeights[FS]; + + for (int i = 0; i < FS; ++i) { + accumWeights[i] = float(0.0); + } + + const int IOOffset = batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + const scalar_t* gradInputFeature = &gradInput[IOOffset]; + float* tempOutputGradWeight = &output[filterIdx * FS * minibatch * numFiltersInBlock]; + + for (int i = 0; i < numIterations; ++i) { + const int inputOffset = i * SB; + + load_input_to_shared(inputFeature, inputOffset, sequenceLength, + i, numIterations, false, tempInput); + load_input_to_shared(gradInputFeature, inputOffset, sequenceLength, + i, numIterations, false, tempGradInput); + __syncthreads(); + + #pragma unroll + for (int j = 0; j < FS; ++j) { + accumWeights[j] += tempInput[tid + j] * tempGradInput[tid + (FS/2)]; + } + + __syncthreads(); + } + + // Row-major sum + for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) { + + // Write to shared memory before reduction + if (tid < sequenceLength) { + temp = accumWeights[filterWeightIdx]; + } else { + temp = float(0.0); + } + + temp = blockReduce(temp); + + const int outputOffset = filterWeightIdx * minibatch * numFiltersInBlock + + batchIdx * numFiltersInBlock + + idxInFilterBlock; + + if (tid == 0) { + tempOutputGradWeight[outputOffset] = temp; + } + } +} + +template +__global__ +void lightconv_grad_wrt_weights_secondpass_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output) { + + assert(blockDim.x == SB); + const int tid = threadIdx.x; + + // What is the id within a minibatch + const int filterIdx = blockIdx.x; + const int filterWeightIdx = blockIdx.y; + + const int inputOffset = filterIdx * FS * minibatch * numFiltersInBlock + + filterWeightIdx * minibatch * numFiltersInBlock; + const float* tempInput = &input[inputOffset]; + + int readIndex = tid; + + float sum = float(0.0); + while (readIndex < (minibatch * numFiltersInBlock)) { + sum += tempInput[readIndex]; + readIndex += SB; + } + + float temp = blockReduce(sum); + + if (tid == 0) { + output[blockIdx.x * FS + blockIdx.y] = temp; + } +} diff --git a/fairseq/modules/lightconv_layer/lightconv_layer.py b/fairseq/modules/lightconv_layer/lightconv_layer.py new file mode 100644 index 0000000..e7e597f --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_layer.py @@ -0,0 +1,137 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import lightconv_cuda +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from torch import nn +from torch.autograd import Function + + +class lightconvFunction(Function): + @staticmethod + def forward(ctx, x, weights, padding_l): + ctx.padding_l = padding_l + outputs = lightconv_cuda.forward(x, weights, padding_l) + variables = [x, weights] + ctx.save_for_backward(*variables) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output): + outputs = lightconv_cuda.backward( + grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors + ) + grad_input, grad_weights = outputs + return grad_input, grad_weights, None + + +@with_incremental_state +class LightconvLayer(nn.Module): + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + weight_softmax=False, + num_heads=1, + weight_dropout=0.0, + bias=False, + ): + super(LightconvLayer, self).__init__() + self.input_size = input_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_softmax = weight_softmax + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + + self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.bias = None + self.reset_parameters() + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + for k, v in state_dict.items(): + if k.endswith(prefix + "weight"): + if v.dim() == 3 and v.size(1) == 1: + state_dict[k] = v.squeeze(1) + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + + def forward(self, x, incremental_state=None): + + # during inference time, incremental BMM is faster + if incremental_state is not None: + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) + + weight = self.weight + if self.weight_softmax: + weight = F.softmax(weight.float(), dim=1).type_as(weight) + + weight = weight[:, -x_unfold.size(2) :] + + K = weight.size(1) + + weight = ( + weight.view(1, H, K) + .expand(T * B, H, K) + .contiguous() + .view(T * B * H, K, 1) + ) + + weight = self.weight_dropout_module(weight) + output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + # during training time, use CUDA kernel + else: + x = x.permute(1, 2, 0).contiguous() + weight = self.weight + if self.weight_softmax: + weight = F.softmax(self.weight, -1) + if self.weight_dropout_module.p: + weight = self.weight_dropout_module(weight) + return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1) + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def half(self): + return self._apply(lambda t: t.half() if t.is_floating_point() else t) diff --git a/fairseq/modules/lightconv_layer/setup.py b/fairseq/modules/lightconv_layer/setup.py new file mode 100644 index 0000000..052635b --- /dev/null +++ b/fairseq/modules/lightconv_layer/setup.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name="lightconv_layer", + ext_modules=[ + CUDAExtension( + "lightconv_cuda", + [ + "lightconv_cuda.cpp", + "lightconv_cuda_kernel.cu", + ], + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/fairseq/modules/lightweight_convolution.py b/fairseq/modules/lightweight_convolution.py new file mode 100644 index 0000000..ec11a95 --- /dev/null +++ b/fairseq/modules/lightweight_convolution.py @@ -0,0 +1,310 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.unfold import unfold1d + + +def LightweightConv( + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + bias=False, +): + if torch.cuda.is_available(): + try: + from fairseq.modules.lightconv_layer import LightconvLayer + + return LightconvLayer( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) + except ImportError as e: + print(e) + return LightweightConv1dTBC( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) + + +class LightweightConv1d(nn.Module): + """Lightweight Convolution assuming the input is BxCxT + This is just an example that explains LightConv clearer than the TBC version. + We don't use this module in the model. + + Args: + input_size: # of channels of the input and output + kernel_size: convolution channels + padding: padding + num_heads: number of heads used. The weight is of shape + `(num_heads, 1, kernel_size)` + weight_softmax: normalize the weight with softmax before the convolution + + Shape: + Input: BxCxT, i.e. (batch_size, input_size, timesteps) + Output: BxCxT, i.e. (batch_size, input_size, timesteps) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding=0, + num_heads=1, + weight_softmax=False, + bias=False, + weight_dropout=0.0, + ): + super().__init__() + self.input_size = input_size + self.kernel_size = kernel_size + self.num_heads = num_heads + self.padding = padding + self.weight_softmax = weight_softmax + self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size)) + + if bias: + self.bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.bias = None + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + + def forward(self, input): + """ + input size: B x C x T + output size: B x C x T + """ + B, C, T = input.size() + H = self.num_heads + + weight = self.weight + if self.weight_softmax: + weight = F.softmax(weight, dim=-1) + + weight = self.weight_dropout_module(weight) + # Merge every C/H entries into the batch dimension (C = self.input_size) + # B x C x T -> (B * C/H) x H x T + # One can also expand the weight to C x 1 x K by a factor of C/H + # and do not reshape the input instead, which is slow though + input = input.view(-1, H, T) + output = F.conv1d(input, weight, padding=self.padding, groups=self.num_heads) + output = output.view(B, C, T) + if self.bias is not None: + output = output + self.bias.view(1, -1, 1) + + return output + + +@with_incremental_state +class LightweightConv1dTBC(nn.Module): + """Lightweight Convolution assuming the input is TxBxC + Args: + input_size: # of channels of the input + kernel_size: convolution channels + padding_l: padding to the left when using "same" padding + num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) + weight_dropout: the drop rate of the DropConnect to drop the weight + weight_softmax: normalize the weight with softmax before the convolution + bias: use bias + + Shape: + Input: TxBxC, i.e. (timesteps, batch_size, input_size) + Output: TxBxC, i.e. (timesteps, batch_size, input_size) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + bias=False, + ): + super().__init__() + self.input_size = input_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.weight_softmax = weight_softmax + + self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.bias = None + + self.reset_parameters() + self.onnx_trace = False + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + + def forward(self, x, incremental_state=None, unfold=False): + """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + args: + x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) + incremental_state: A dict to keep the state + unfold: unfold the input or not. If not, we use the matrix trick instead + """ + unfold = unfold or (incremental_state is not None) + + if unfold: + output = self._forward_unfolded(x, incremental_state) + else: + output = self._forward_expanded(x, incremental_state) + + if self.bias is not None: + output = output + self.bias.view(1, 1, -1) + return output + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def _forward_unfolded(self, x, incremental_state): + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + weight = self.weight.view(H, K) + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) + else: + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0) + x_unfold = x_unfold.view(T * B * H, R, K) + + if self.weight_softmax: + weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as( + weight + ) + + if incremental_state is not None: + weight = weight[:, -x_unfold.size(2) :] + K = weight.size(1) + + weight = ( + weight.view(1, H, K).expand(T * B, H, K).contiguous().view(T * B * H, K, 1) + ) + + weight = self.weight_dropout_module(weight) + output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + def _forward_expanded(self, x, incremental_state): + """Turn the convolution filters into band matrices and do matrix multiplication. + This is faster when the sequence is short, but less memory efficient. + This is not used in the decoder during inference. + """ + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + weight = self.weight.view(H, K) + if self.weight_softmax: + weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as( + weight + ) + weight = weight.view(1, H, K).expand(T * B, H, K).contiguous() + weight = weight.view(T, B * H, K).transpose(0, 1) + + x = x.view(T, B * H, R).transpose(0, 1) + P = self.padding_l + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_( + weight + ) + weight_expanded = weight_expanded.narrow(2, P, T) + weight_expanded = self.weight_dropout_module(weight_expanded) + + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def extra_repr(self): + s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, bias={}".format( + self.input_size, + self.kernel_size, + self.padding_l, + self.num_heads, + self.weight_softmax, + self.bias is not None, + ) + if self.weight_dropout_module.p > 0.0: + s += ", weight_dropout={}".format(self.weight_dropout_module.p) + return s diff --git a/fairseq/modules/linearized_convolution.py b/fairseq/modules/linearized_convolution.py new file mode 100644 index 0000000..09a8f20 --- /dev/null +++ b/fairseq/modules/linearized_convolution.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state + +from .conv_tbc import ConvTBC + + +@with_incremental_state +class LinearizedConvolution(ConvTBC): + """An optimized version of nn.Conv1d. + + At training time, this module uses ConvTBC, which is an optimized version + of Conv1d. At inference time, it optimizes incremental generation (i.e., + one time step at a time) by replacing the convolutions with linear layers. + Note that the input order changes from training to inference. + """ + + def __init__(self, in_channels, out_channels, kernel_size, **kwargs): + super().__init__(in_channels, out_channels, kernel_size, **kwargs) + self._linearized_weight = None + self.register_backward_hook(self._clear_linearized_weight) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + state = ConvTBC.state_dict(self, destination, prefix, keep_vars=keep_vars) + # don't store redundant _linearized_weight in checkpoints + if prefix + "_linearized_weight" in state: + del state[prefix + "_linearized_weight"] + return state + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + if prefix + "_linearized_weight" in state_dict: + del state_dict[prefix + "_linearized_weight"] + + def forward(self, input, incremental_state=None): + """ + Args: + incremental_state: Used to buffer signal; if not None, then input is + expected to contain a single frame. If the input order changes + between time steps, call reorder_incremental_state. + Input: + Time x Batch x Channel during training + Batch x Time x Channel during inference + """ + if incremental_state is None: + output = super().forward(input) + if self.kernel_size[0] > 1 and self.padding[0] > 0: + # remove future timesteps added by padding + output = output[: -self.padding[0], :, :] + return output + + # reshape weight + weight = self._get_linearized_weight() + kw = self.kernel_size[0] + + bsz = input.size(0) # input: bsz x len x dim + if kw > 1: + input = input.data + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = input.new(bsz, kw, input.size(2)).zero_() + self._set_input_buffer(incremental_state, input_buffer) + else: + # shift buffer + input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone() + # append next input + input_buffer[:, -1, :] = input[:, -1, :] + input = input_buffer + with torch.no_grad(): + output = F.linear(input.view(bsz, -1), weight, self.bias) + return output.view(bsz, 1, -1) + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(0, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def _get_linearized_weight(self): + if self._linearized_weight is None: + kw = self.kernel_size[0] + weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() + assert weight.size() == (self.out_channels, kw, self.in_channels) + self._linearized_weight = torch.nn.Parameter( + weight.view(self.out_channels, -1) + ) + return self._linearized_weight + + def _clear_linearized_weight(self, *args): + self._linearized_weight = None diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py new file mode 100644 index 0000000..99f95de --- /dev/null +++ b/fairseq/modules/multihead_attention.py @@ -0,0 +1,488 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor, nn +from torch.nn import Parameter + + +@with_incremental_state +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + self.tpu = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def prepare_for_tpu_(self, **kwargs): + self.tpu = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if ( + not self.onnx_trace + and not self.tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + ): + assert key is not None and value is not None + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not self.tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + elif key_padding_mask is not None: + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention and input_buffer_k.size( + 0 + ) == new_order.size(0): + break + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ + dim : 2 * dim + ] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value diff --git a/fairseq/modules/positional_embedding.py b/fairseq/modules/positional_embedding.py new file mode 100644 index 0000000..8e94e35 --- /dev/null +++ b/fairseq/modules/positional_embedding.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from .learned_positional_embedding import LearnedPositionalEmbedding +from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding + + +def PositionalEmbedding( + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + learned: bool = False, +): + if learned: + # if padding_idx is specified then offset the embedding ids by + # this index and adjust num_embeddings appropriately + # TODO: The right place for this offset would be inside + # LearnedPositionalEmbedding. Move this there for a cleaner implementation. + if padding_idx is not None: + num_embeddings = num_embeddings + padding_idx + 1 + m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + else: + m = SinusoidalPositionalEmbedding( + embedding_dim, + padding_idx, + init_size=num_embeddings + padding_idx + 1, + ) + return m diff --git a/fairseq/modules/quant_noise.py b/fairseq/modules/quant_noise.py new file mode 100644 index 0000000..d777dfb --- /dev/null +++ b/fairseq/modules/quant_noise.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module diff --git a/fairseq/modules/quantization/__init__.py b/fairseq/modules/quantization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fairseq/modules/quantization/pq/__init__.py b/fairseq/modules/quantization/pq/__init__.py new file mode 100644 index 0000000..5b10b51 --- /dev/null +++ b/fairseq/modules/quantization/pq/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .utils import SizeTracker, quantize_model_ # NOQA diff --git a/fairseq/modules/quantization/pq/em.py b/fairseq/modules/quantization/pq/em.py new file mode 100644 index 0000000..6f15c3e --- /dev/null +++ b/fairseq/modules/quantization/pq/em.py @@ -0,0 +1,211 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import random +from collections import Counter + +import torch + + +class EM: + """ + EM algorithm used to quantize the columns of W to minimize + + ||W - W_hat||^2 + + Args: + - W: weight matrix of size (in_features x out_features) + - n_iter: number of k-means iterations + - n_centroids: number of centroids (size of codebook) + - eps: for cluster reassignment when an empty cluster is found + - max_tentatives for cluster reassignment when an empty cluster is found + - verbose: print error after each iteration + + Remarks: + - If one cluster is empty, the most populated cluster is split into + two clusters + - All the relevant dimensions are specified in the code + """ + + def __init__( + self, W, n_centroids=256, n_iter=20, eps=1e-6, max_tentatives=30, verbose=True + ): + self.W = W + self.n_centroids = n_centroids + self.n_iter = n_iter + self.eps = eps + self.max_tentatives = max_tentatives + self.verbose = verbose + self.centroids = torch.Tensor() + self.assignments = torch.Tensor() + self.objective = [] + + def initialize_centroids(self): + """ + Initializes the centroids by sampling random columns from W. + """ + + in_features, out_features = self.W.size() + indices = torch.randint( + low=0, high=out_features, size=(self.n_centroids,) + ).long() + self.centroids = self.W[:, indices].t() # (n_centroids x in_features) + + def step(self, i): + """ + There are two standard steps for each iteration: expectation (E) and + minimization (M). The E-step (assignment) is performed with an exhaustive + search and the M-step (centroid computation) is performed with + the exact solution. + + Args: + - i: step number + + Remarks: + - The E-step heavily uses PyTorch broadcasting to speed up computations + and reduce the memory overhead + """ + + # assignments (E-step) + distances = self.compute_distances() # (n_centroids x out_features) + self.assignments = torch.argmin(distances, dim=0) # (out_features) + n_empty_clusters = self.resolve_empty_clusters() + + # centroids (M-step) + for k in range(self.n_centroids): + W_k = self.W[:, self.assignments == k] # (in_features x size_of_cluster_k) + self.centroids[k] = W_k.mean(dim=1) # (in_features) + + # book-keeping + obj = (self.centroids[self.assignments].t() - self.W).norm(p=2).item() + self.objective.append(obj) + if self.verbose: + logging.info( + f"Iteration: {i},\t" + f"objective: {obj:.6f},\t" + f"resolved empty clusters: {n_empty_clusters}" + ) + + def resolve_empty_clusters(self): + """ + If one cluster is empty, the most populated cluster is split into + two clusters by shifting the respective centroids. This is done + iteratively for a fixed number of tentatives. + """ + + # empty clusters + counts = Counter(map(lambda x: x.item(), self.assignments)) + empty_clusters = set(range(self.n_centroids)) - set(counts.keys()) + n_empty_clusters = len(empty_clusters) + + tentatives = 0 + while len(empty_clusters) > 0: + # given an empty cluster, find most populated cluster and split it into two + k = random.choice(list(empty_clusters)) + m = counts.most_common(1)[0][0] + e = torch.randn_like(self.centroids[m]) * self.eps + self.centroids[k] = self.centroids[m].clone() + self.centroids[k] += e + self.centroids[m] -= e + + # recompute assignments + distances = self.compute_distances() # (n_centroids x out_features) + self.assignments = torch.argmin(distances, dim=0) # (out_features) + + # check for empty clusters + counts = Counter(map(lambda x: x.item(), self.assignments)) + empty_clusters = set(range(self.n_centroids)) - set(counts.keys()) + + # increment tentatives + if tentatives == self.max_tentatives: + logging.info( + f"Could not resolve all empty clusters, {len(empty_clusters)} remaining" + ) + raise EmptyClusterResolveError + tentatives += 1 + + return n_empty_clusters + + def compute_distances(self): + """ + For every centroid m, computes + + ||M - m[None, :]||_2 + + Remarks: + - We rely on PyTorch's broadcasting to speed up computations + and reduce the memory overhead + - Without chunking, the sizes in the broadcasting are modified as: + (n_centroids x n_samples x out_features) -> (n_centroids x out_features) + - The broadcasting computation is automatically chunked so that + the tensors fit into the memory of the GPU + """ + + nb_centroids_chunks = 1 + + while True: + try: + return torch.cat( + [ + (self.W[None, :, :] - centroids_c[:, :, None]).norm(p=2, dim=1) + for centroids_c in self.centroids.chunk( + nb_centroids_chunks, dim=0 + ) + ], + dim=0, + ) + except RuntimeError: + nb_centroids_chunks *= 2 + + def assign(self): + """ + Assigns each column of W to its closest centroid, thus essentially + performing the E-step in train(). + + Remarks: + - The function must be called after train() or after loading + centroids using self.load(), otherwise it will return empty tensors + """ + + distances = self.compute_distances() # (n_centroids x out_features) + self.assignments = torch.argmin(distances, dim=0) # (out_features) + + def save(self, path, layer): + """ + Saves centroids and assignments. + + Args: + - path: folder used to save centroids and assignments + """ + + torch.save(self.centroids, os.path.join(path, "{}_centroids.pth".format(layer))) + torch.save( + self.assignments, os.path.join(path, "{}_assignments.pth".format(layer)) + ) + torch.save(self.objective, os.path.join(path, "{}_objective.pth".format(layer))) + + def load(self, path, layer): + """ + Loads centroids and assignments from a given path + + Args: + - path: folder use to load centroids and assignments + """ + + self.centroids = torch.load( + os.path.join(path, "{}_centroids.pth".format(layer)) + ) + self.assignments = torch.load( + os.path.join(path, "{}_assignments.pth".format(layer)) + ) + self.objective = torch.load( + os.path.join(path, "{}_objective.pth".format(layer)) + ) + + +class EmptyClusterResolveError(Exception): + pass diff --git a/fairseq/modules/quantization/pq/modules/__init__.py b/fairseq/modules/quantization/pq/modules/__init__.py new file mode 100644 index 0000000..b67c8e8 --- /dev/null +++ b/fairseq/modules/quantization/pq/modules/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .qconv import PQConv2d # NOQA +from .qemb import PQEmbedding # NOQA +from .qlinear import PQLinear # NOQA diff --git a/fairseq/modules/quantization/pq/modules/qconv.py b/fairseq/modules/quantization/pq/modules/qconv.py new file mode 100644 index 0000000..d15ec19 --- /dev/null +++ b/fairseq/modules/quantization/pq/modules/qconv.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair + + +class PQConv2d(nn.Module): + """ + Quantized counterpart of nn.Conv2d module. Stores the centroid, the assignments + and the non-quantized biases. The full weight is re-instantiated at each forward + pass and autograd automatically computes the gradients with respect to the + centroids. + + Args: + - centroids: centroids of size n_centroids x block_size + - assignments: assignments of the centroids to the subvectors + of size self.out_channels x n_blocks + - bias: the non-quantized bias, must be either torch.Tensor or None + + Remarks: + - We refer the reader to the official documentation of the nn.Conv2d module + for the other arguments and the behavior of the module. + - Performance tests on GPU show that this implementation is 10% slower than + the non-quantized nn.Conv2d module for a standard training loop. + - During the backward, the gradients are averaged by cluster and not summed. + This explains the hook registered to the centroids. + """ + + def __init__( + self, + centroids, + assignments, + bias, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + ): + super(PQConv2d, self).__init__() + self.block_size = centroids.size(1) + self.n_centroids = centroids.size(0) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.padding_mode = padding_mode + # check compatibility + if in_channels // groups * np.prod(self.kernel_size) % self.block_size != 0: + raise ValueError("Wrong PQ sizes") + if len(assignments) % out_channels != 0: + raise ValueError("Wrong PQ sizes") + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + # define parameters + self.centroids = nn.Parameter(centroids, requires_grad=True) + self.register_buffer("assignments", assignments) + self.register_buffer("counts", torch.bincount(assignments).type_as(centroids)) + if bias is not None: + self.bias = nn.Parameter(bias) + else: + self.register_parameter("bias", None) + # register hook for averaging gradients per centroids instead of summing + self.centroids.register_hook(lambda x: x / self.counts[:, None]) + + @property + def weight(self): + return ( + self.centroids[self.assignments] + .reshape(-1, self.out_channels, self.block_size) + .permute(1, 0, 2) + .reshape( + self.out_channels, self.in_channels // self.groups, *self.kernel_size + ) + ) + + def forward(self, x): + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def extra_repr(self): + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is None: + s += ", bias=False" + if self.padding_mode != "zeros": + s += ", padding_mode={padding_mode}" + s += ", n_centroids={n_centroids}, block_size={block_size}" + return s.format(**self.__dict__) diff --git a/fairseq/modules/quantization/pq/modules/qemb.py b/fairseq/modules/quantization/pq/modules/qemb.py new file mode 100644 index 0000000..3a74ad3 --- /dev/null +++ b/fairseq/modules/quantization/pq/modules/qemb.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PQEmbedding(nn.Module): + """ + Quantized counterpart of nn.Embedding module. Stores the centroids and + the assignments. The full weight is re-instantiated at each forward + pass. + + Args: + - centroids: centroids of size n_centroids x block_size + - assignments: assignments of the centroids to the subvectors + of size self.out_features x n_blocks + - bias: the non-quantized bias + + Remarks: + - We refer the reader to the official documentation of the nn.Embedding module + for the other arguments and the behavior of the module + - Performance tests on GPU show that this implementation is 10% slower than + the non-quantized nn.Embedding module for a standard training loop. + """ + + def __init__( + self, + centroids, + assignments, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + ): + super(PQEmbedding, self).__init__() + self.block_size = centroids.size(1) + self.n_centroids = centroids.size(0) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + # check compatibility + if self.embedding_dim % self.block_size != 0: + raise ValueError("Wrong PQ sizes") + if len(assignments) % self.num_embeddings != 0: + raise ValueError("Wrong PQ sizes") + # define parameters + self.centroids = nn.Parameter(centroids, requires_grad=True) + self.register_buffer("assignments", assignments) + self.register_buffer("counts", torch.bincount(assignments).type_as(centroids)) + + @property + def weight(self): + return ( + self.centroids[self.assignments] + .reshape(-1, self.num_embeddings, self.block_size) + .permute(1, 0, 2) + .flatten(1, 2) + ) + + def forward(self, input): + return F.embedding( + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self): + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + s += ", n_centroids={n_centroids}, block_size={block_size}" + + return s.format(**self.__dict__) diff --git a/fairseq/modules/quantization/pq/modules/qlinear.py b/fairseq/modules/quantization/pq/modules/qlinear.py new file mode 100644 index 0000000..9bdd25a --- /dev/null +++ b/fairseq/modules/quantization/pq/modules/qlinear.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PQLinear(nn.Module): + """ + Quantized counterpart of nn.Linear module. Stores the centroid, the assignments + and the non-quantized biases. The full weight is re-instantiated at each forward + pass. + + Args: + - centroids: centroids of size n_centroids x block_size + - assignments: assignments of the centroids to the subvectors + of size self.out_features x n_blocks + - bias: the non-quantized bias + + Remarks: + - We refer the reader to the official documentation of the nn.Linear module + for the other arguments and the behavior of the module + - Performance tests on GPU show that this implementation is 15% slower than + the non-quantized nn.Linear module for a standard training loop. + """ + + def __init__(self, centroids, assignments, bias, in_features, out_features): + super(PQLinear, self).__init__() + self.block_size = centroids.size(1) + self.n_centroids = centroids.size(0) + self.in_features = in_features + self.out_features = out_features + # check compatibility + if self.in_features % self.block_size != 0: + raise ValueError("Wrong PQ sizes") + if len(assignments) % self.out_features != 0: + raise ValueError("Wrong PQ sizes") + # define parameters + self.centroids = nn.Parameter(centroids, requires_grad=True) + self.register_buffer("assignments", assignments) + self.register_buffer("counts", torch.bincount(assignments).type_as(centroids)) + if bias is not None: + self.bias = nn.Parameter(bias) + else: + self.register_parameter("bias", None) + + @property + def weight(self): + return ( + self.centroids[self.assignments] + .reshape(-1, self.out_features, self.block_size) + .permute(1, 0, 2) + .flatten(1, 2) + ) + + def forward(self, x): + return F.linear( + x, + self.weight, + self.bias, + ) + + def extra_repr(self): + return f"in_features={self.in_features},\ + out_features={self.out_features},\ + n_centroids={self.n_centroids},\ + block_size={self.block_size},\ + bias={self.bias is not None}" diff --git a/fairseq/modules/quantization/pq/pq.py b/fairseq/modules/quantization/pq/pq.py new file mode 100644 index 0000000..eddc2eb --- /dev/null +++ b/fairseq/modules/quantization/pq/pq.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .em import EM, EmptyClusterResolveError + + +class PQ(EM): + """ + Quantizes the layer weights W with the standard Product Quantization + technique. This learns a codebook of codewords or centroids of size + block_size from W. For further reference on using PQ to quantize + neural networks, see "And the Bit Goes Down: Revisiting the Quantization + of Neural Networks", Stock et al., ICLR 2020. + + PQ is performed in two steps: + (1) The matrix W (weights or fully-connected or convolutional layer) + is reshaped to (block_size, -1). + - If W is fully-connected (2D), its columns are split into + blocks of size block_size. + - If W is convolutional (4D), its filters are split along the + spatial dimension. + (2) We apply the standard EM/k-means algorithm to the resulting reshaped matrix. + + Args: + - W: weight matrix to quantize of size (in_features x out_features) + - block_size: size of the blocks (subvectors) + - n_centroids: number of centroids + - n_iter: number of k-means iterations + - eps: for cluster reassignment when an empty cluster is found + - max_tentatives for cluster reassignment when an empty cluster is found + - verbose: print information after each iteration + + Remarks: + - block_size be compatible with the shape of W + """ + + def __init__( + self, + W, + block_size, + n_centroids=256, + n_iter=20, + eps=1e-6, + max_tentatives=30, + verbose=True, + ): + self.block_size = block_size + W_reshaped = self._reshape(W) + super(PQ, self).__init__( + W_reshaped, + n_centroids=n_centroids, + n_iter=n_iter, + eps=eps, + max_tentatives=max_tentatives, + verbose=verbose, + ) + + def _reshape(self, W): + """ + Reshapes the matrix W as expained in step (1). + """ + + # fully connected: by convention the weight has size out_features x in_features + if len(W.size()) == 2: + self.out_features, self.in_features = W.size() + assert ( + self.in_features % self.block_size == 0 + ), "Linear: n_blocks must be a multiple of in_features" + return ( + W.reshape(self.out_features, -1, self.block_size) + .permute(2, 1, 0) + .flatten(1, 2) + ) + + # convolutional: we reshape along the spatial dimension + elif len(W.size()) == 4: + self.out_channels, self.in_channels, self.k_h, self.k_w = W.size() + assert ( + self.in_channels * self.k_h * self.k_w + ) % self.block_size == 0, ( + "Conv2d: n_blocks must be a multiple of in_channels * k_h * k_w" + ) + return ( + W.reshape(self.out_channels, -1, self.block_size) + .permute(2, 1, 0) + .flatten(1, 2) + ) + # not implemented + else: + raise NotImplementedError(W.size()) + + def encode(self): + """ + Performs self.n_iter EM steps. + """ + + self.initialize_centroids() + for i in range(self.n_iter): + try: + self.step(i) + except EmptyClusterResolveError: + break + + def decode(self): + """ + Returns the encoded full weight matrix. Must be called after + the encode function. + """ + + # fully connected case + if "k_h" not in self.__dict__: + return ( + self.centroids[self.assignments] + .reshape(-1, self.out_features, self.block_size) + .permute(1, 0, 2) + .flatten(1, 2) + ) + + # convolutional case + else: + return ( + self.centroids[self.assignments] + .reshape(-1, self.out_channels, self.block_size) + .permute(1, 0, 2) + .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w) + ) diff --git a/fairseq/modules/quantization/pq/utils.py b/fairseq/modules/quantization/pq/utils.py new file mode 100644 index 0000000..03b15e4 --- /dev/null +++ b/fairseq/modules/quantization/pq/utils.py @@ -0,0 +1,337 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import re +from operator import attrgetter, itemgetter + +import numpy as np +import torch.distributed as dist +import torch.nn as nn + +from .modules import PQConv2d, PQEmbedding, PQLinear +from .pq import PQ + + +def quantize_model_( + model, + size_tracker, + layers_to_quantize, + block_sizes_config, + n_centroids_config, + step=0, + n_iter=15, + eps=1e-6, + max_tentatives=100, + verbose=True, +): + """ + Quantize a model in-place by stages. All the targeted + layers are replaced by their quantized counterpart, + and the model is ready for the finetuning of the + centroids in a standard training loop (no modifications + required). Note that we do not quantize biases. + + Args: + - model: a nn.Module + - size_tracker: useful for tracking quatization statistics + - layers_to_quantize: a list containing regexps for + filtering the layers to quantize at each stage according + to their name (as in model.named_parameters()) + - block_sizes_config: dict like + { + 'Conv2d': ('kernel_size', {'(3, 3)': 9, '(1, 1)': 4}), + 'Linear': ('in_features', {'*': 8}) + } + For instance, all conv2d layers with kernel size 3x3 have + a block size of 9 and all Linear layers are quantized with + a block size of 8, irrespective of their size. + - n_centroids_config: dict like + { + 'Conv2d': ('kernel_size', {'*': 256}), + 'Linear': ('in_features', {'*': 256}) + } + For instance, all conv2d layers are quantized with 256 centroids + - step: the layers to quantize inplace corresponding + to layers_to_quantize[step] + """ + + quantized_layers = get_layers(model, layers_to_quantize[step]) + + for layer in quantized_layers: + + # book-keeping + is_master_process = (not dist.is_initialized()) or ( + dist.is_initialized() and dist.get_rank() == 0 + ) + verbose = verbose and is_master_process + + # get block size and centroids + module = attrgetter(layer)(model) + block_size = get_param(module, layer, block_sizes_config) + n_centroids = get_param(module, layer, n_centroids_config) + if verbose: + logging.info( + f"Quantizing layer {layer} with block size {block_size} and {n_centroids} centroids" + ) + + # quantize layer + weight = module.weight.data.clone() + is_bias = "bias" in [x[0] for x in module.named_parameters()] + bias = module.bias.data.clone() if is_bias else None + quantizer = PQ( + weight, + block_size, + n_centroids=n_centroids, + n_iter=n_iter, + eps=eps, + max_tentatives=max_tentatives, + verbose=verbose, + ) + + # quantization performed on all GPUs with same seed + quantizer.encode() + centroids = quantizer.centroids.contiguous() + assignments = quantizer.assignments.contiguous() + + # broadcast results to make sure weights are up-to-date + if dist.is_initialized(): + dist.broadcast(centroids, 0) + dist.broadcast(assignments, 0) + + # instantiate the quantized counterpart + if isinstance(module, nn.Linear): + out_features, in_features = map( + lambda k: module.__dict__[k], ["out_features", "in_features"] + ) + quantized_module = PQLinear( + centroids, assignments, bias, in_features, out_features + ) + elif isinstance(module, nn.Embedding): + num_embeddings, embedding_dim = map( + lambda k: module.__dict__[k], ["num_embeddings", "embedding_dim"] + ) + quantized_module = PQEmbedding( + centroids, assignments, num_embeddings, embedding_dim + ) + elif isinstance(module, nn.Conv2d): + out_channels, in_channels, kernel_size = map( + lambda k: module.__dict__[k], + ["out_channels", "in_channels", "kernel_size"], + ) + stride, padding, dilation, groups, padding_mode = map( + lambda k: module.__dict__[k], + ["stride", "padding", "dilation", "groups", "padding_mode"], + ) + + quantized_module = PQConv2d( + centroids, + assignments, + bias, + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + ) + else: + raise ValueError(f"Module {module} not yet supported for quantization") + + # replace layer by its quantized counterpart + attrsetter(layer)(model, quantized_module) + + # update statistics + size_tracker.update(weight, block_size, n_centroids) + + # return name of quantized layers + return quantized_layers + + +def get_layers(model, filter_regexp): + """ + Filters out the layers according to a regexp. Note that + we omit biases. + + Args: + - model: a nn.Module + - filter_regexp: a regexp to filter the layers to keep + according to their name in model.named_parameters(). + For instance, the regexp: + + down_layers\\.[123456]\\.(conv[12]|identity\\.conv)) + + is keeping blocks down_layers from 1 to 6, and inside + each block is keeping conv1, conv2 and identity.conv. + + Remarks: + - We add (module\\.)? at the beginning of the regexp to + account for the possible use of nn.parallel.DataParallel + """ + + # get all parameter names + all_layers = map(itemgetter(0), model.named_parameters()) + + # remove biases + all_layers = filter(lambda x: "bias" not in x, all_layers) + + # remove .weight in all other names (or .weight_orig is spectral norm) + all_layers = map(lambda x: x.replace(".weight_orig", ""), all_layers) + all_layers = map(lambda x: x.replace(".weight", ""), all_layers) + + # return filtered layers + filter_regexp = "(module\\.)?" + "(" + filter_regexp + ")" + r = re.compile(filter_regexp) + + return list(filter(r.match, all_layers)) + + +def get_param(module, layer_name, param_config): + """ + Given a quantization configuration, get the right parameter + for the module to be quantized. + + Args: + - module: a nn.Module + - layer_name: the name of the layer + - param_config: a dict like + { + 'Conv2d': ('kernel_size', {'(3, 3)': 9, '(1, 1)': 4}), + 'Linear': ('in_features', {'*': 8}) + } + For instance, all conv2d layers with kernel size 3x3 have + a block size of 9 and all Linear layers are quantized with + a block size of 8, irrespective of their size. + + Remarks: + - if 'fuzzy_name' is passed as a parameter, layers whose layer_name + include 'fuzzy_name' will be assigned the given parameter. + In the following example, conv.expand layers will have a block + size of 9 while conv.reduce will have a block size of 4 and all + other layers will have a block size of 2. + { + 'Conv2d': ('fuzzy_name', {'expand': 9, 'reduce': 4, '*': 2}), + 'Linear': ('fuzzy_name', {'classifier': 8, 'projection': 4}) + } + + """ + + layer_type = module.__class__.__name__ + + if layer_type not in param_config: + raise KeyError(f"Layer type {layer_type} not in config for layer {module}") + + feature, params = param_config[module.__class__.__name__] + + if feature != "fuzzy_name": + feature_value = str(getattr(module, feature)) + if feature_value not in params: + if "*" in params: + feature_value = "*" + else: + raise KeyError( + f"{feature}={feature_value} not in config for layer {module}" + ) + else: + feature_values = [name for name in params if name in layer_name] + if len(feature_values) == 0: + if "*" in params: + feature_value = "*" + else: + raise KeyError(f"name={layer_name} not in config for {module}") + else: + feature_value = feature_values[0] + + return params[feature_value] + + +class SizeTracker(object): + """ + Class to keep track of the compressed network size with iPQ. + + Args: + - model: a nn.Module + + Remarks: + - The compressed size is the sum of three components + for each layer in the network: + (1) Storing the centroids given by iPQ in fp16 + (2) Storing the assignments of the blocks in int8 + (3) Storing all non-compressed elements such as biases + - This cost in only valid if we use 256 centroids (then + indexing can indeed by done with int8). + """ + + def __init__(self, model): + self.model = model + self.size_non_compressed_model = self.compute_size() + self.size_non_quantized = self.size_non_compressed_model + self.size_index = 0 + self.size_centroids = 0 + self.n_quantized_layers = 0 + + def compute_size(self): + """ + Computes the size of the model (in MB). + """ + + res = 0 + for _, p in self.model.named_parameters(): + res += p.numel() + return res * 4 / 1024 / 1024 + + def update(self, W, block_size, n_centroids): + """ + Updates the running statistics when quantizing a new layer. + """ + + # bits per weights + bits_per_weight = np.log2(n_centroids) / block_size + self.n_quantized_layers += 1 + + # size of indexing the subvectors of size block_size (in MB) + size_index_layer = bits_per_weight * W.numel() / 8 / 1024 / 1024 + self.size_index += size_index_layer + + # size of the centroids stored in float16 (in MB) + size_centroids_layer = n_centroids * block_size * 2 / 1024 / 1024 + self.size_centroids += size_centroids_layer + + # size of non-compressed layers, e.g. LayerNorms or biases (in MB) + size_uncompressed_layer = W.numel() * 4 / 1024 / 1024 + self.size_non_quantized -= size_uncompressed_layer + + def __repr__(self): + size_compressed = ( + self.size_index + self.size_centroids + self.size_non_quantized + ) + compression_ratio = self.size_non_compressed_model / size_compressed # NOQA + return ( + f"Non-compressed model size: {self.size_non_compressed_model:.2f} MB. " + f"After quantizing {self.n_quantized_layers} layers, size " + f"(indexing + centroids + other): {self.size_index:.2f} MB + " + f"{self.size_centroids:.2f} MB + {self.size_non_quantized:.2f} MB = " + f"{size_compressed:.2f} MB, compression ratio: {compression_ratio:.2f}x" + ) + + +def attrsetter(*items): + def resolve_attr(obj, attr): + attrs = attr.split(".") + head = attrs[:-1] + tail = attrs[-1] + + for name in head: + obj = getattr(obj, name) + return obj, tail + + def g(obj, val): + for attr in items: + resolved_obj, resolved_attr = resolve_attr(obj, attr) + setattr(resolved_obj, resolved_attr, val) + + return g diff --git a/fairseq/modules/quantization/quantization_options.py b/fairseq/modules/quantization/quantization_options.py new file mode 100644 index 0000000..b46d682 --- /dev/null +++ b/fairseq/modules/quantization/quantization_options.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +def parse_config_yaml(yaml_data): + # Initialize to default options. + quantization_options = { + "n_centroids": { + "Linear": ["in_features", {"*": 256}], + "Embedding": ["embedding_dim", {"*": 256}], + }, + "block_sizes": { + "Linear": ["fuzzy_name", {"fc": 8, "attn": 4, "emb": 4}], + "Embedding": ["fuzzy_name", {"emb": 8}], + }, + "layers_to_quantize": [ + "decoder\\.layers\\.\\d+\\.fc[12]", + "decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]", + "decoder\\.layers\\.\\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)", + ], + } + + if "n_centroids" in yaml_data: + quantization_options["n_centroids"] = { + layer: convert_yaml_to_tuple(layer_data) + for layer, layer_data in yaml_data["n_centroids"].items() + } + if "block_sizes" in yaml_data: + quantization_options["block_sizes"] = { + layer: convert_yaml_to_tuple(layer_data) + for layer, layer_data in yaml_data["block_sizes"].items() + } + if "layers_to_quantize" in yaml_data: + quantization_options["layers_to_quantize"] = yaml_data["layers_to_quantize"] + + return quantization_options + + +def convert_yaml_to_tuple(yaml_dictionary): + """Converts a yaml dictionary with two keys: `key` and `value` into a two + argument tuple of those values.""" + return (yaml_dictionary["key"], yaml_dictionary["value"]) diff --git a/fairseq/modules/quantization/scalar/__init__.py b/fairseq/modules/quantization/scalar/__init__.py new file mode 100644 index 0000000..143834f --- /dev/null +++ b/fairseq/modules/quantization/scalar/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .utils import quantize_model_ # NOQA diff --git a/fairseq/modules/quantization/scalar/modules/__init__.py b/fairseq/modules/quantization/scalar/modules/__init__.py new file mode 100644 index 0000000..8031d9c --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .qact import ActivationQuantizer # NOQA +from .qconv import IntConv2d # NOQA +from .qemb import IntEmbedding # NOQA +from .qlinear import IntLinear # NOQA diff --git a/fairseq/modules/quantization/scalar/modules/qact.py b/fairseq/modules/quantization/scalar/modules/qact.py new file mode 100644 index 0000000..c5dd1d6 --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/qact.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from ..ops import emulate_int + + +class ActivationQuantizer: + """ + Fake scalar quantization of the activations using a forward hook. + + Args: + - module. a nn.Module for which we quantize the *post-activations* + - p: proportion of activations to quantize, set by default to 1 + - update_step: to recompute quantization parameters + - bits: number of bits for quantization + - method: choose among {"tensor", "histogram", "channel"} + - clamp_threshold: to prevent gradients overflow + + Remarks: + - Parameters scale and zero_point are recomputed every update_step + forward pass to reduce the overhead + - For the list of quantization methods and number of bits, see ops.py + - To remove the hook from the module, simply call self.handle.remove() + - At test time, the activations are fully quantized + - We use the straight-through estimator so that the gradients + back-propagate nicely in the network, this is implemented with + the detach() trick + - The activations are hard-clamped in [-clamp_threshold, clamp_threshold] + to prevent overflow during the backward pass + """ + + def __init__( + self, + module, + p=1, + update_step=1000, + bits=8, + method="histogram", + clamp_threshold=5, + ): + self.module = module + self.p = p + self.update_step = update_step + self.counter = 0 + self.bits = bits + self.method = method + self.clamp_threshold = clamp_threshold + self.handle = None + self.register_hook() + + def register_hook(self): + # forward hook + def quantize_hook(module, x, y): + + # update parameters every 1000 iterations + if self.counter % self.update_step == 0: + self.scale = None + self.zero_point = None + self.counter += 1 + + # train with QuantNoise and evaluate the fully quantized network + p = self.p if self.module.training else 1 + + # quantize activations + y_q, self.scale, self.zero_point = emulate_int( + y.detach(), + bits=self.bits, + method=self.method, + scale=self.scale, + zero_point=self.zero_point, + ) + + # mask to apply noise + mask = torch.zeros_like(y) + mask.bernoulli_(1 - p) + noise = (y_q - y).masked_fill(mask.bool(), 0) + + # using straight-through estimator (STE) + clamp_low = -self.scale * self.zero_point + clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point) + return torch.clamp(y, clamp_low.item(), clamp_high.item()) + noise.detach() + + # register hook + self.handle = self.module.register_forward_hook(quantize_hook) diff --git a/fairseq/modules/quantization/scalar/modules/qconv.py b/fairseq/modules/quantization/scalar/modules/qconv.py new file mode 100644 index 0000000..83788c6 --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/qconv.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.utils import _pair + +from ..ops import emulate_int + + +class IntConv2d(_ConvNd): + """ + Quantized counterpart of the nn.Conv2d module that applies QuantNoise during training. + + Args: + - standard nn.Conv2d parameters + - p: amount of noise to inject (0 = no quantization, 1 = quantize all the weights) + - bits: number of bits + - method: choose among {"tensor", "histogram", "channel"} + - update_step: recompute scale and zero_point every update_steps iterations + + Remarks: + - We use the straight-thgourh estimator so that the gradients + back-propagate nicely in the network, this is implemented with + the detach() trick + - Parameters scale and zero_point are recomputed every update_step + forward pass to reduce the overhead + - At test time, the weights are fully quantized + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + p=0, + bits=8, + method="histogram", + update_step=1000, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(IntConv2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + padding_mode, + ) + + # quantization parameters + self.p = p + self.bits = bits + self.method = method + self.update_step = update_step + self.counter = 0 + + def _conv_forward(self, input, weight): + if self.padding_mode != "zeros": + return F.conv2d( + F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), + weight, + self.bias, + self.stride, + _pair(0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def forward(self, input): + # train with QuantNoise and evaluate the fully quantized network + p = self.p if self.training else 1 + + # update parameters every 100 iterations + if self.counter % self.update_step == 0: + self.scale = None + self.zero_point = None + self.counter += 1 + + # quantize weight + weight_quantized, self.scale, self.zero_point = emulate_int( + self.weight.detach(), + bits=self.bits, + method=self.method, + scale=self.scale, + zero_point=self.zero_point, + ) + + # mask to apply noise + mask = torch.zeros_like(self.weight) + mask.bernoulli_(1 - p) + noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) + + # using straight-through estimator (STE) + clamp_low = -self.scale * self.zero_point + clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point) + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + + noise.detach() + ) + + # return output + output = self._conv_forward(input, weight) + return output + + def extra_repr(self): + return ( + "in_channels={}, out_channels={}, kernel_size={}, stride={}, " + "padding={}, dilation={}, groups={}, bias={}, quant_noise={}, " + "bits={}, method={}".format( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.p, + self.bits, + self.method, + ) + ) diff --git a/fairseq/modules/quantization/scalar/modules/qemb.py b/fairseq/modules/quantization/scalar/modules/qemb.py new file mode 100644 index 0000000..d6cf06e --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/qemb.py @@ -0,0 +1,147 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..ops import emulate_int + + +class IntEmbedding(nn.Module): + """ + Quantized counterpart of the nn.Embedding module that applies QuantNoise during training. + + Args: + - num_embeddings: number of tokens + - embedding_dim: embedding dimension + - p: amount of noise to inject (0 = no quantization, 1 = quantize all the weights) + - bits: number of bits + - method: choose among {"tensor", "histogram", "channel"} + - update_step: recompute scale and zero_point every update_steps iterations + + Remarks: + - We use the straight-through estimator so that the gradients + back-propagate nicely in the network, this is implemented with + the detach() trick + - Parameters scale and zero_point are recomputed every update_step + forward pass to reduce the overhead + - At test time, the weights are fully quantized + """ + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + p=0, + update_step=1000, + bits=8, + method="histogram", + ): + super(IntEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if _weight is None: + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + self.weight = nn.Parameter(_weight) + self.sparse = sparse + + # quantization parameters + self.p = p + self.bits = bits + self.method = method + self.update_step = update_step + self.counter = 0 + + def reset_parameters(self): + nn.init.normal_(self.weight) + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input): + # train with QuantNoise and evaluate the fully quantized network + p = self.p if self.training else 1 + + # update parameters every 1000 iterations + if self.counter % self.update_step == 0: + self.scale = None + self.zero_point = None + self.counter += 1 + + # quantize weight + weight_quantized, self.scale, self.zero_point = emulate_int( + self.weight.detach(), + bits=self.bits, + method=self.method, + scale=self.scale, + zero_point=self.zero_point, + ) + + # mask to apply noise + mask = torch.zeros_like(self.weight) + mask.bernoulli_(1 - p) + noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) + + # using straight-through estimator (STE) + clamp_low = -self.scale * self.zero_point + clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point) + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + + noise.detach() + ) + + # return output + output = F.embedding( + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + return output + + def extra_repr(self): + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + s += "quant_noise={p}, bits={bits}, method={method}" + return s.format(**self.__dict__) diff --git a/fairseq/modules/quantization/scalar/modules/qlinear.py b/fairseq/modules/quantization/scalar/modules/qlinear.py new file mode 100644 index 0000000..9db1559 --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/qlinear.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..ops import emulate_int + + +class IntLinear(nn.Module): + """ + Quantized counterpart of the nn.Linear module that applies QuantNoise during training. + + Args: + - in_features: input features + - out_features: output features + - bias: bias or not + - p: amount of noise to inject (0 = no quantization, 1 = quantize all the weights) + - bits: number of bits + - method: choose among {"tensor", "histogram", "channel"} + - update_step: recompute scale and zero_point every update_steps iterations + + Remarks: + - We use the straight-through estimator so that the gradients + back-propagate nicely in the network, this is implemented with + the detach() trick. + - Parameters scale and zero_point are recomputed every update_step + forward pass to reduce the overhead + - At test time, the weights are fully quantized + """ + + def __init__( + self, + in_features, + out_features, + bias=True, + p=0, + update_step=3000, + bits=8, + method="histogram", + ): + super(IntLinear, self).__init__() + self.in_features = int(in_features) + self.out_features = int(out_features) + self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) + self.chosen_bias = bias + if self.chosen_bias: + self.bias = torch.nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + # quantization parameters + self.p = p + self.bits = bits + self.method = method + self.update_step = update_step + self.counter = 0 + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.chosen_bias: + nn.init.constant_(self.bias, 0.0) + return + + def forward(self, input): + # train with QuantNoise and evaluate the fully quantized network + p = self.p if self.training else 1 + + # update parameters every 100 iterations + if self.counter % self.update_step == 0: + self.scale = None + self.zero_point = None + self.counter += 1 + + # quantize weight + weight_quantized, self.scale, self.zero_point = emulate_int( + self.weight.detach(), + bits=self.bits, + method=self.method, + scale=self.scale, + zero_point=self.zero_point, + ) + + # mask to apply noise + mask = torch.zeros_like(self.weight) + mask.bernoulli_(1 - p) + noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) + + # using straight-through estimator (STE) + clamp_low = -self.scale * self.zero_point + clamp_high = self.scale * (2 ** self.bits - 1 - self.zero_point) + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + + noise.detach() + ) + + # return output + output = F.linear(input, weight, self.bias) + return output + + def extra_repr(self): + return "in_features={}, out_features={}, bias={}, quant_noise={}, bits={}, method={}".format( + self.in_features, + self.out_features, + self.bias is not None, + self.p, + self.bits, + self.method, + ) diff --git a/fairseq/modules/quantization/scalar/ops.py b/fairseq/modules/quantization/scalar/ops.py new file mode 100644 index 0000000..2a85515 --- /dev/null +++ b/fairseq/modules/quantization/scalar/ops.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def emulate_int(w, bits, method, scale=None, zero_point=None): + q = globals()[f"emulate_int{bits}_{method}"] + return q(w, scale=scale, zero_point=zero_point) + + +def quantize(w, scale, zero_point): + return ( + torch.clamp(torch.round(w / scale + zero_point), 0, 255) - zero_point + ) * scale + + +def emulate_int8_histogram(w, scale=None, zero_point=None): + if scale is None: + obs = torch.quantization.observer.HistogramObserver() + _ = obs(w.float()) + scale, zero_point = obs.calculate_qparams() + scale = scale.cuda().type_as(w) + zero_point = zero_point.cuda().type_as(w) + return quantize(w, scale, zero_point), scale, zero_point + + +def emulate_int8_channel(w, scale=None, zero_point=None): + if scale is None: + obs = torch.quantization.observer.PerChannelMinMaxObserver( + ch_axis=-1, qscheme=torch.per_channel_symmetric + ) + _ = obs(w) + scale, zero_point, ch_axis = obs.get_qparams() + scale = scale.cuda().type_as(w) + zero_point = zero_point.cuda().type_as(w) + return quantize(w, scale, zero_point), scale, zero_point + + +def emulate_int8_tensor(w, scale=None, zero_point=None): + if scale is None: + obs = torch.quantization.observer.MinMaxObserver() + _ = obs(w) + scale, zero_point = obs.calculate_qparams() + scale = scale.cuda().type_as(w) + zero_point = zero_point.cuda().type_as(w) + return quantize(w, scale, zero_point), scale, zero_point diff --git a/fairseq/modules/quantization/scalar/utils.py b/fairseq/modules/quantization/scalar/utils.py new file mode 100644 index 0000000..32cf616 --- /dev/null +++ b/fairseq/modules/quantization/scalar/utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from operator import attrgetter + +import torch.distributed as dist +import torch.nn as nn + +from ..pq.utils import attrsetter, get_layers +from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear + + +MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d} + + +def quantize_model_(model, p=0.2, bits=8, update_step=3000): + """ + Replaces all modules with their scalar quantized counterpart and + registers hooks to quantize the post-ativations of those modules. + + Args: + - model: a nn.Module + - p: amount of noise (0 for no noise, 1 to quantize all the weights/activations) + - bits: number of bits + - update_step: update quantization parameters every update_step steps + """ + + # quantize all layers + quantized_layers = get_layers(model, "(.*?)") + + for layer in quantized_layers: + + # book-keeping + is_master_process = (not dist.is_initialized()) or ( + dist.is_initialized() and dist.get_rank() == 0 + ) + + # recover module + module = attrgetter(layer)(model) + if is_master_process: + logging.info( + f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}" + ) + + # quantization params + q_params = { + "p": p, + "update_step": update_step, + "bits": bits, + "method": "histogram", + "counter": 0, + } + + # instantiate the quantized counterpart + if isinstance(module, tuple(MAPPING.keys())): + QuantizedModule = MAPPING[module.__class__] + quantized_module = QuantizedModule.__new__(QuantizedModule) + params = module.__dict__ + params.update(q_params) + quantized_module.__dict__.update(params) + + else: + if is_master_process: + logging.info(f"Module {module} not yet supported for quantization") + continue + + # activation quantization + a_q = ActivationQuantizer(quantized_module, p=0, bits=bits, method="histogram") + + # replace layer by its quantized counterpart + attrsetter(layer)(model, quantized_module) + + # return name of quantized layers + return quantized_layers diff --git a/fairseq/modules/same_pad.py b/fairseq/modules/same_pad.py new file mode 100644 index 0000000..b46f94d --- /dev/null +++ b/fairseq/modules/same_pad.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from torch import nn + + +class SamePad(nn.Module): + def __init__(self, kernel_size): + super().__init__() + self.remove = kernel_size % 2 == 0 + + def forward(self, x): + if self.remove: + x = x[:, :, :-1] + return x diff --git a/fairseq/modules/scalar_bias.py b/fairseq/modules/scalar_bias.py new file mode 100644 index 0000000..c96247c --- /dev/null +++ b/fairseq/modules/scalar_bias.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch + + +class ScalarBias(torch.autograd.Function): + """ + Adds a vector of scalars, used in self-attention mechanism to allow + the model to optionally attend to this vector instead of the past + """ + + @staticmethod + def forward(ctx, input, dim, bias_init): + size = list(input.size()) + size[dim] += 1 + output = input.new(*size).fill_(bias_init) + output.narrow(dim, 1, size[dim] - 1).copy_(input) + ctx.dim = dim + return output + + @staticmethod + def backward(ctx, grad): + return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None + + +def scalar_bias(input, dim, bias_init=0): + return ScalarBias.apply(input, dim, bias_init) diff --git a/fairseq/modules/sinusoidal_positional_embedding.py b/fairseq/modules/sinusoidal_positional_embedding.py new file mode 100644 index 0000000..857830f --- /dev/null +++ b/fairseq/modules/sinusoidal_positional_embedding.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional + +import torch +import torch.onnx.operators +from fairseq import utils +from torch import Tensor, nn + + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, embedding_dim, padding_idx + ) + self.onnx_trace = False + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + self.max_positions = int(1e5) + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + @staticmethod + def get_embedding( + num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None + ): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( + 1 + ) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( + num_embeddings, -1 + ) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward( + self, + input, + incremental_state: Optional[Any] = None, + timestep: Optional[Tensor] = None, + positions: Optional[Any] = None, + ): + """Input is expected to be of size [bsz x seqlen].""" + bspair = torch.onnx.operators.shape_as_tensor(input) + bsz, seq_len = bspair[0], bspair[1] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, self.embedding_dim, self.padding_idx + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + if self.onnx_trace: + return ( + self.weights.index_select(index=self.padding_idx + pos, dim=0) + .unsqueeze(1) + .repeat(bsz, 1, 1) + ) + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = utils.make_positions( + input, self.padding_idx, onnx_trace=self.onnx_trace + ) + if self.onnx_trace: + flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) + embedding_shape = torch.cat( + (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long)) + ) + embeddings = torch.onnx.operators.reshape_from_tensor_shape( + flat_embeddings, embedding_shape + ) + return embeddings + return ( + self.weights.index_select(0, positions.view(-1)) + .view(bsz, seq_len, -1) + .detach() + ) diff --git a/fairseq/modules/sparse_multihead_attention.py b/fairseq/modules/sparse_multihead_attention.py new file mode 100644 index 0000000..3cbd9d6 --- /dev/null +++ b/fairseq/modules/sparse_multihead_attention.py @@ -0,0 +1,140 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + +from .multihead_attention import MultiheadAttention + + +class SparseMultiheadAttention(MultiheadAttention): + """Sparse Multi-Headed Attention. + + "Generating Long Sequences with Sparse Transformers". Implements + fixed factorized self attention, where l=stride and c=expressivity. + A(1) includes all words in the stride window and A(2) takes a summary of c + words from the end of each stride window. + If is_bidirectional=False, we do not include any words past the current word, + as in the paper. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + stride=32, + expressivity=8, + is_bidirectional=True, + ): + + super().__init__( + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn, + self_attention, + encoder_decoder_attention, + ) + + self.is_bidirectional = is_bidirectional + self.stride = stride + self.expressivity = expressivity + assert self.stride > 0 and self.stride >= self.expressivity + + # Used for Ai(2) calculations - beginning of [l-c, l] range + def compute_checkpoint(self, word_index): + if word_index % self.stride == 0 and word_index != 0: + checkpoint_index = word_index - self.expressivity + else: + checkpoint_index = ( + math.floor(word_index / self.stride) * self.stride + + self.stride + - self.expressivity + ) + return checkpoint_index + + # Computes Ai(2) + def compute_subset_summaries(self, absolute_max): + checkpoint_index = self.compute_checkpoint(0) + subset_two = set() + while checkpoint_index <= absolute_max - 1: + summary = set( + range( + checkpoint_index, + min(checkpoint_index + self.expressivity + 1, absolute_max), + ) + ) + subset_two = subset_two.union(summary) + checkpoint_index = self.compute_checkpoint(checkpoint_index + self.stride) + return subset_two + + # Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf + def compute_fixed_attention_subset(self, word_index, tgt_len): + # +1s account for range function; [min, max) -> [min, max] + if not self.is_bidirectional: + absolute_max = word_index + 1 + else: + absolute_max = tgt_len + + # Subset 1 - whole window + rounded_index = ( + math.floor((word_index + self.stride) / self.stride) * self.stride + ) + if word_index % self.stride == 0 and word_index != 0: + subset_one = set( + range(word_index - self.stride, min(absolute_max, word_index + 1)) + ) + else: + subset_one = set( + range( + max(0, rounded_index - self.stride), + min(absolute_max, rounded_index + 1), + ) + ) + + # Subset 2 - summary per window + # If bidirectional, subset 2 is the same for every index + subset_two = set() + if not self.is_bidirectional: + subset_two = self.compute_subset_summaries(absolute_max) + + return subset_one.union(subset_two) + + # Compute sparse mask - if bidirectional, can pre-compute and store + def buffered_sparse_mask(self, tensor, tgt_len, src_len): + assert tgt_len > self.stride + sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float("-inf")) + + # If bidirectional, subset 2 is the same for every index + subset_summaries = set() + if self.is_bidirectional: + subset_summaries = self.compute_subset_summaries(tgt_len) + + for i in range(tgt_len): + fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len) + fixed_attention_subset = fixed_attention_subset.union(subset_summaries) + included_word_indices = torch.LongTensor(list(fixed_attention_subset)) + sparse_mask[i].index_fill_(0, included_word_indices, 0) + return sparse_mask.type_as(tensor) + + def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): + sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len) + sparse_mask = sparse_mask.unsqueeze(0).expand( + bsz * self.num_heads, tgt_len, src_len + ) + attn_weights += sparse_mask diff --git a/fairseq/modules/sparse_transformer_sentence_encoder.py b/fairseq/modules/sparse_transformer_sentence_encoder.py new file mode 100644 index 0000000..f41ec09 --- /dev/null +++ b/fairseq/modules/sparse_transformer_sentence_encoder.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +from fairseq.modules import TransformerSentenceEncoder +from fairseq.modules.sparse_transformer_sentence_encoder_layer import ( + SparseTransformerSentenceEncoderLayer, +) + + +class SparseTransformerSentenceEncoder(TransformerSentenceEncoder): + """ + Sparse implementation of the TransformerSentenceEncoder + - see SparseMultiheadAttention + """ + + def __init__( + self, + padding_idx: int, + vocab_size: int, + num_encoder_layers: int = 6, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + max_seq_len: int = 256, + num_segments: int = 2, + use_position_embeddings: bool = True, + offset_positions_by_padding: bool = True, + encoder_normalize_before: bool = False, + apply_bert_init: bool = False, + activation_fn: str = "relu", + learned_pos_embedding: bool = True, + embed_scale: float = None, + freeze_embeddings: bool = False, + n_trans_layers_to_freeze: int = 0, + export: bool = False, + is_bidirectional: bool = True, + stride: int = 32, + expressivity: int = 8, + ) -> None: + + super().__init__( + padding_idx, + vocab_size, + num_encoder_layers, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + max_seq_len, + num_segments, + use_position_embeddings, + offset_positions_by_padding, + encoder_normalize_before, + apply_bert_init, + activation_fn, + learned_pos_embedding, + embed_scale, + freeze_embeddings, + n_trans_layers_to_freeze, + export, + ) + + self.layers = nn.ModuleList( + [ + SparseTransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + is_bidirectional=is_bidirectional, + stride=stride, + expressivity=expressivity, + ) + for _ in range(num_encoder_layers) + ] + ) + + def freeze_module_params(m): + if m is not None: + for p in m.parameters(): + p.requires_grad = False + + for layer in range(n_trans_layers_to_freeze): + freeze_module_params(self.layers[layer]) diff --git a/fairseq/modules/sparse_transformer_sentence_encoder_layer.py b/fairseq/modules/sparse_transformer_sentence_encoder_layer.py new file mode 100644 index 0000000..d95da59 --- /dev/null +++ b/fairseq/modules/sparse_transformer_sentence_encoder_layer.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.modules import TransformerSentenceEncoderLayer +from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention + + +class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): + """ + Implements a Sprase Transformer Encoder Layer (see SparseMultiheadAttention) + """ + + def __init__( + self, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + export: bool = False, + is_bidirectional: bool = True, + stride: int = 32, + expressivity: int = 8, + ) -> None: + + super().__init__( + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, + ) + + self.self_attn = SparseMultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + add_bias_kv=False, + add_zero_attn=False, + self_attention=True, + is_bidirectional=is_bidirectional, + stride=stride, + expressivity=expressivity, + ) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py new file mode 100644 index 0000000..6f3c79d --- /dev/null +++ b/fairseq/modules/transformer_layer.py @@ -0,0 +1,414 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.modules import LayerNorm, MultiheadAttention +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer block. + + In the original paper each operation (multi-head attention or FFN) is + postprocessed with: `dropout -> add residual -> layernorm`. In the + tensor2tensor code they suggest that learning is more robust when + preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.encoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + def __init__(self, args): + super().__init__() + self.embed_dim = args.encoder_embed_dim + self.quant_noise = getattr(args, 'quant_noise_pq', 0) + self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 + self.self_attn = self.build_self_attention(self.embed_dim, args) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, 'activation_fn', 'relu') or "relu" + ) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use args.relu_dropout + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = args.encoder_normalize_before + self.fc1 = self.build_fc1( + self.embed_dim, + args.encoder_ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + args.encoder_ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim) + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) + + def build_self_attention(self, embed_dim, args): + return MultiheadAttention( + embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def residual_connection(self, x, residual): + return residual + x + + def upgrade_state_dict_named(self, state_dict, name): + """ + Rename layer norm states from `...layer_norms.0.weight` to + `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to + `...final_layer_norm.weight` + """ + layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layer_norms.{}.{}".format(name, old, m) + if k in state_dict: + state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] + del state_dict[k] + + def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, seq_len)` where padding elements are indicated by ``1``. + attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, + where `tgt_len` is the length of output and `src_len` is the + length of input, though here both are equal to `seq_len`. + `attn_mask[tgt_i, src_j] = 1` means that when calculating the + embedding for `tgt_i`, we exclude (mask out) `src_j`. This is + useful for strided self-attention. + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + # anything in original attn_mask = 1, becomes -1e8 + # anything in original attn_mask = 0, becomes 0 + # Note that we cannot use -inf here, because at some edge cases, + # the attention weight (before softmax) for some padded element in query + # will become -inf, which results in NaN in model parameters + if attn_mask is not None: + attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + attn_mask=attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + return x + + +class TransformerDecoderLayer(nn.Module): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.decoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.quant_noise = getattr(args, "quant_noise_pq", 0) + self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + + self.cross_self_attention = getattr(args, "cross_self_attention", False) + + self.self_attn = self.build_self_attention( + self.embed_dim, + args, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + self.activation_fn = utils.get_activation_fn( + activation=str(args.activation_fn) + if getattr(args, "activation_fn", None) is not None + else "relu" + ) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use args.relu_dropout + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = args.decoder_normalize_before + + # use layerNorm rather than FusedLayerNorm for exporting. + # char_inputs can be used to determint this. + # TODO remove this once we update apex with the fix + export = getattr(args, "char_inputs", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + self.fc1 = self.build_fc1( + self.embed_dim, + args.decoder_ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + args.decoder_ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.need_attn = True + + self.onnx_trace = False + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_self_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=not getattr(args, "cross_self_attention", False), + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def build_encoder_attention(self, embed_dim, args): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def residual_connection(self, x, residual): + return residual + x + + def forward( + self, + x, + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + prev_self_attn_state: Optional[List[torch.Tensor]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + if need_head_weights: + need_attn = True + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state, saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 + ) + if self_attn_padding_mask is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + if self.encoder_attn is not None and encoder_out is not None: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + if self.onnx_trace and incremental_state is not None: + saved_state = self.self_attn._get_input_buffer(incremental_state) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] + return x, attn, self_attn_state + return x, attn, None + + def make_generation_fast_(self, need_attn: bool = False, **kwargs): + self.need_attn = need_attn diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py new file mode 100644 index 0000000..208488f --- /dev/null +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -0,0 +1,286 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from fairseq.modules import ( + FairseqDropout, + LayerDropModuleList, + LayerNorm, + MultiheadAttention, + PositionalEmbedding, + TransformerSentenceEncoderLayer, +) +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + module.q_proj.weight.data.normal_(mean=0.0, std=0.02) + module.k_proj.weight.data.normal_(mean=0.0, std=0.02) + module.v_proj.weight.data.normal_(mean=0.0, std=0.02) + + +class TransformerSentenceEncoder(nn.Module): + """ + Implementation for a Bi-directional Transformer based Sentence Encoder used + in BERT/XLM style pre-trained models. + + This first computes the token embedding using the token embedding matrix, + position embeddings (if specified) and segment embeddings + (if specified). After applying the specified number of + TransformerEncoderLayers, it outputs all the internal states of the + encoder as well as the final representation associated with the first + token (usually CLS token). + + Input: + - tokens: B x T matrix representing sentences + - segment_labels: B x T matrix representing segment label for tokens + + Output: + - a tuple of the following: + - a list of internal model states used to compute the + predictions where each tensor has shape T x B x C + - sentence representation associated with first input token + in format B x C. + """ + + def __init__( + self, + padding_idx: int, + vocab_size: int, + num_encoder_layers: int = 6, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + layerdrop: float = 0.0, + max_seq_len: int = 256, + num_segments: int = 2, + use_position_embeddings: bool = True, + offset_positions_by_padding: bool = True, + encoder_normalize_before: bool = False, + apply_bert_init: bool = False, + activation_fn: str = "relu", + learned_pos_embedding: bool = True, + embed_scale: float = None, + freeze_embeddings: bool = False, + n_trans_layers_to_freeze: int = 0, + export: bool = False, + traceable: bool = False, + q_noise: float = 0.0, + qn_block_size: int = 8, + ) -> None: + + super().__init__() + self.padding_idx = padding_idx + self.vocab_size = vocab_size + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.layerdrop = layerdrop + self.max_seq_len = max_seq_len + self.embedding_dim = embedding_dim + self.num_segments = num_segments + self.use_position_embeddings = use_position_embeddings + self.apply_bert_init = apply_bert_init + self.learned_pos_embedding = learned_pos_embedding + self.traceable = traceable + self.tpu = False # whether we're on TPU + + self.embed_tokens = self.build_embedding( + self.vocab_size, self.embedding_dim, self.padding_idx + ) + self.embed_scale = embed_scale + + if q_noise > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), + q_noise, + qn_block_size, + ) + else: + self.quant_noise = None + + self.segment_embeddings = ( + nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None) + if self.num_segments > 0 + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + self.max_seq_len, + self.embedding_dim, + padding_idx=(self.padding_idx if offset_positions_by_padding else None), + learned=self.learned_pos_embedding, + ) + if self.use_position_embeddings + else None + ) + + if self.layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_transformer_sentence_encoder_layer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=self.dropout_module.p, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + for _ in range(num_encoder_layers) + ] + ) + + if encoder_normalize_before: + self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) + else: + self.emb_layer_norm = None + + # Apply initialization of model params after building the model + if self.apply_bert_init: + self.apply(init_bert_params) + + def freeze_module_params(m): + if m is not None: + for p in m.parameters(): + p.requires_grad = False + + if freeze_embeddings: + freeze_module_params(self.embed_tokens) + freeze_module_params(self.segment_embeddings) + freeze_module_params(self.embed_positions) + freeze_module_params(self.emb_layer_norm) + + for layer in range(n_trans_layers_to_freeze): + freeze_module_params(self.layers[layer]) + + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return nn.Embedding(vocab_size, embedding_dim, padding_idx) + + def build_transformer_sentence_encoder_layer( + self, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, + q_noise, + qn_block_size, + ): + return TransformerSentenceEncoderLayer( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + def prepare_for_tpu_(self, **kwargs): + self.tpu = True + + def forward( + self, + tokens: torch.Tensor, + segment_labels: torch.Tensor = None, + last_state_only: bool = False, + positions: Optional[torch.Tensor] = None, + token_embeddings: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + # compute padding mask. This is needed for multi-head attention + padding_mask = tokens.eq(self.padding_idx) + if not self.traceable and not self.tpu and not padding_mask.any(): + padding_mask = None + + if token_embeddings is not None: + x = token_embeddings + else: + x = self.embed_tokens(tokens) + + if self.embed_scale is not None: + x = x * self.embed_scale + + if self.embed_positions is not None: + x = x + self.embed_positions(tokens, positions=positions) + + if self.segment_embeddings is not None and segment_labels is not None: + x = x + self.segment_embeddings(segment_labels) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.emb_layer_norm is not None: + x = self.emb_layer_norm(x) + + x = self.dropout_module(x) + + # account for padding while computing the representation + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + inner_states = [] + if not last_state_only: + inner_states.append(x) + + for layer in self.layers: + x, _ = layer(x, self_attn_padding_mask=padding_mask) + if not last_state_only: + inner_states.append(x) + + sentence_rep = x[0, :, :] + + if last_state_only: + inner_states = [x] + + if self.traceable: + return torch.stack(inner_states), sentence_rep + else: + return inner_states, sentence_rep diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py new file mode 100644 index 0000000..3589c60 --- /dev/null +++ b/fairseq/modules/transformer_sentence_encoder_layer.py @@ -0,0 +1,134 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.modules import LayerNorm, MultiheadAttention +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + export: bool = False, + q_noise: float = 0.0, + qn_block_size: int = 8, + init_fn: Callable = None, + ) -> None: + super().__init__() + + if init_fn is not None: + init_fn() + + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.activation_dropout_module = FairseqDropout( + activation_dropout, module_name=self.__class__.__name__ + ) + + # Initialize blocks + self.activation_fn = utils.get_activation_fn(activation_fn) + self.self_attn = self.build_self_attention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) + + self.fc1 = self.build_fc1( + self.embedding_dim, + ffn_embedding_dim, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + self.fc2 = self.build_fc2( + ffn_embedding_dim, + self.embedding_dim, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_self_attention( + self, + embed_dim, + num_attention_heads, + dropout, + self_attention, + q_noise, + qn_block_size, + ): + return MultiheadAttention( + embed_dim, + num_attention_heads, + dropout=dropout, + self_attention=True, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + residual = x + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout_module(x) + x = residual + x + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + x = self.final_layer_norm(x) + return x, attn diff --git a/fairseq/modules/transpose_last.py b/fairseq/modules/transpose_last.py new file mode 100644 index 0000000..e578b3e --- /dev/null +++ b/fairseq/modules/transpose_last.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +transpose last 2 dimensions of the input +""" + +import torch.nn as nn + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) diff --git a/fairseq/modules/unfold.py b/fairseq/modules/unfold.py new file mode 100644 index 0000000..138272f --- /dev/null +++ b/fairseq/modules/unfold.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn.functional as F + + +def unfold1d(x, kernel_size, padding_l, pad_value=0): + """unfold T x B x C to T x B x C x K""" + if kernel_size > 1: + T, B, C = x.size() + x = F.pad( + x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value + ) + x = x.as_strided((T, B, C, kernel_size), (B * C, C, 1, B * C)) + else: + x = x.unsqueeze(3) + return x diff --git a/fairseq/modules/vggblock.py b/fairseq/modules/vggblock.py new file mode 100644 index 0000000..ee5ee19 --- /dev/null +++ b/fairseq/modules/vggblock.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +from collections.abc import Iterable +from itertools import repeat + +import torch +import torch.nn as nn + + +def _pair(v): + if isinstance(v, Iterable): + assert len(v) == 2, "len(v) != 2" + return v + return tuple(repeat(v, 2)) + + +def infer_conv_output_dim(conv_op, input_dim, sample_inchannel): + sample_seq_len = 200 + sample_bsz = 10 + x = torch.randn(sample_bsz, sample_inchannel, sample_seq_len, input_dim) + # N x C x H x W + # N: sample_bsz, C: sample_inchannel, H: sample_seq_len, W: input_dim + x = conv_op(x) + # N x C x H x W + x = x.transpose(1, 2) + # N x H x C x W + bsz, seq = x.size()[:2] + per_channel_dim = x.size()[3] + # bsz: N, seq: H, CxW the rest + return x.contiguous().view(bsz, seq, -1).size(-1), per_channel_dim + + +class VGGBlock(torch.nn.Module): + """ + VGG motibated cnn module https://arxiv.org/pdf/1409.1556.pdf + + Args: + in_channels: (int) number of input channels (typically 1) + out_channels: (int) number of output channels + conv_kernel_size: convolution channels + pooling_kernel_size: the size of the pooling window to take a max over + num_conv_layers: (int) number of convolution layers + input_dim: (int) input dimension + conv_stride: the stride of the convolving kernel. + Can be a single number or a tuple (sH, sW) Default: 1 + padding: implicit paddings on both sides of the input. + Can be a single number or a tuple (padH, padW). Default: None + layer_norm: (bool) if layer norm is going to be applied. Default: False + + Shape: + Input: BxCxTxfeat, i.e. (batch_size, input_size, timesteps, features) + Output: BxCxTxfeat, i.e. (batch_size, input_size, timesteps, features) + """ + + def __init__( + self, + in_channels, + out_channels, + conv_kernel_size, + pooling_kernel_size, + num_conv_layers, + input_dim, + conv_stride=1, + padding=None, + layer_norm=False, + ): + assert ( + input_dim is not None + ), "Need input_dim for LayerNorm and infer_conv_output_dim" + super(VGGBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_kernel_size = _pair(conv_kernel_size) + self.pooling_kernel_size = _pair(pooling_kernel_size) + self.num_conv_layers = num_conv_layers + self.padding = ( + tuple(e // 2 for e in self.conv_kernel_size) + if padding is None + else _pair(padding) + ) + self.conv_stride = _pair(conv_stride) + + self.layers = nn.ModuleList() + for layer in range(num_conv_layers): + conv_op = nn.Conv2d( + in_channels if layer == 0 else out_channels, + out_channels, + self.conv_kernel_size, + stride=self.conv_stride, + padding=self.padding, + ) + self.layers.append(conv_op) + if layer_norm: + conv_output_dim, per_channel_dim = infer_conv_output_dim( + conv_op, input_dim, in_channels if layer == 0 else out_channels + ) + self.layers.append(nn.LayerNorm(per_channel_dim)) + input_dim = per_channel_dim + self.layers.append(nn.ReLU()) + + if self.pooling_kernel_size is not None: + pool_op = nn.MaxPool2d(kernel_size=self.pooling_kernel_size, ceil_mode=True) + self.layers.append(pool_op) + self.total_output_dim, self.output_dim = infer_conv_output_dim( + pool_op, input_dim, out_channels + ) + + def forward(self, x): + for i, _ in enumerate(self.layers): + x = self.layers[i](x) + return x diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py new file mode 100644 index 0000000..faa8031 --- /dev/null +++ b/fairseq/nan_detector.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch + + +logger = logging.getLogger(__name__) + + +class NanDetector: + """ + Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name + """ + + def __init__(self, model, forward=True, backward=True): + self.bhooks = [] + self.fhooks = [] + self.forward = forward + self.backward = backward + self.named_parameters = list(model.named_parameters()) + self.reset() + + for name, mod in model.named_modules(): + mod.__module_name = name + self.add_hooks(mod) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + # Dump out all model gnorms to enable better debugging + norm = {} + gradients = {} + for name, param in self.named_parameters: + if param.grad is not None: + grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32) + norm[name] = grad_norm.item() + if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any(): + gradients[name] = param.grad.data + if len(gradients) > 0: + logger.info("Detected nan/inf grad norm, dumping norms...") + logger.info(f"norms: {norm}") + logger.info(f"gradients: {gradients}") + + self.close() + + def add_hooks(self, module): + if self.forward: + self.fhooks.append(module.register_forward_hook(self.fhook_fn)) + if self.backward: + self.bhooks.append(module.register_backward_hook(self.bhook_fn)) + + def reset(self): + self.has_printed_f = False + self.has_printed_b = False + + def _detect(self, tensor, name, backward): + err = None + if ( + torch.is_floating_point(tensor) + # single value tensors (like the loss) will not provide much info + and tensor.numel() >= 2 + ): + with torch.no_grad(): + if torch.isnan(tensor).any(): + err = "NaN" + elif torch.isinf(tensor).any(): + err = "Inf" + if err is not None: + err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}" + return err + + def _apply(self, module, inp, x, backward): + if torch.is_tensor(x): + if isinstance(inp, tuple) and len(inp) > 0: + inp = inp[0] + err = self._detect(x, module.__module_name, backward) + if err is not None: + if torch.is_tensor(inp) and not backward: + err += ( + f" input max: {inp.max().item()}, input min: {inp.min().item()}" + ) + + has_printed_attr = "has_printed_b" if backward else "has_printed_f" + logger.warning(err) + setattr(self, has_printed_attr, True) + elif isinstance(x, dict): + for v in x.values(): + self._apply(module, inp, v, backward) + elif isinstance(x, list) or isinstance(x, tuple): + for v in x: + self._apply(module, inp, v, backward) + + def fhook_fn(self, module, inp, output): + if not self.has_printed_f: + self._apply(module, inp, output, backward=False) + + def bhook_fn(self, module, inp, output): + if not self.has_printed_b: + self._apply(module, inp, output, backward=True) + + def close(self): + for hook in self.fhooks + self.bhooks: + hook.remove() diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py new file mode 100644 index 0000000..d8e5817 --- /dev/null +++ b/fairseq/optim/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from fairseq import registry +from fairseq.optim.bmuf import FairseqBMUF # noqa +from fairseq.optim.fairseq_optimizer import ( # noqa + FairseqOptimizer, + LegacyFairseqOptimizer, +) +from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer +from fairseq.optim.shard import shard_ +from omegaconf import DictConfig + +__all__ = [ + "FairseqOptimizer", + "FP16Optimizer", + "MemoryEfficientFP16Optimizer", + "shard_", +] + +( + _build_optimizer, + register_optimizer, + OPTIMIZER_REGISTRY, + OPTIMIZER_DATACLASS_REGISTRY, +) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True) + + +def build_optimizer( + cfg: DictConfig, params, *extra_args, **extra_kwargs +): + if all(isinstance(p, dict) for p in params): + params = [t for p in params for t in p.values()] + params = list(filter(lambda p: p.requires_grad, params)) + return _build_optimizer(cfg, params, *extra_args, **extra_kwargs) + + +# automatically import any Python files in the optim/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.optim." + file_name) diff --git a/fairseq/optim/adadelta.py b/fairseq/optim/adadelta.py new file mode 100644 index 0000000..f1a2154 --- /dev/null +++ b/fairseq/optim/adadelta.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adadelta") +class Adadelta(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO', + help='coefficient used for computing a running average of squared gradients') + parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS', + help='term added to the denominator to improve numerical stability') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "rho": self.args.adadelta_rho, + "eps": self.args.adadelta_eps, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return True diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py new file mode 100644 index 0000000..91745ce --- /dev/null +++ b/fairseq/optim/adafactor.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adafactor") +class FairseqAdafactor(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = Adafactor(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E", + help='epsilons for Adafactor optimizer') + parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C", + help='threshold for clipping update root mean square') + parser.add_argument('--decay-rate', type=float, default=-0.8, metavar="D", + help='decay rate of the second moment estimator') + parser.add_argument('--beta1', type=float, default=None, metavar="B", + help='beta for first moment estimator. Optional') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--scale-parameter', action='store_true', + help='scale learning rate by root mean square of parameter') + parser.add_argument('--relative-step', action='store_true', + help='set learning rate to inverse square root of timestep,' + 'otherwise use external learning rate') + parser.add_argument('--warmup-init', action='store_true', + help='use relative step for warm-up learning rate schedule') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + Note : Convergence issues empirically observed with fp16 on. + Might require search for appropriate configuration. + """ + return { + "lr": self.args.lr[0], + "eps": eval(self.args.adafactor_eps), + "clip_threshold": self.args.clip_threshold, + "decay_rate": self.args.decay_rate, + "beta1": self.args.beta1, + "weight_decay": self.args.weight_decay, + "scale_parameter": self.args.scale_parameter, # defaults to False + "relative_step": self.args.relative_step, # defaults to False + "warmup_init": self.args.warmup_init, + } + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + + This implementation is based on: + `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate + depending on the *scale_parameter*, *relative_step* and + *warmup_init* options. To use a manual (external) learning rate + schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constans for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of + final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square + gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient + (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of + parameter (default: True) + relative_step (bool): if True, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual lr and relative_step options") + if warmup_init and not relative_step: + raise ValueError("warmup_init requires relative_step=True") + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + relative_step=relative_step, + warmup_init=warmup_init, + ) + super(Adafactor, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return False + + def _get_lr(self, param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = ( + 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + ) + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + def _get_options(self, param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + def _rms(self, tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = ( + (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) + .rsqrt_() + .unsqueeze(-1) + ) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + group["lr"] = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad ** 2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=1.0 - beta2t + ) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=1.0 - beta2t + ) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_( + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0) + ) + update.mul_(group["lr"]) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py new file mode 100644 index 0000000..a79b6c3 --- /dev/null +++ b/fairseq/optim/adagrad.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adagrad") +class Adagrad(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return True diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py new file mode 100644 index 0000000..9b8ddff --- /dev/null +++ b/fairseq/optim/adam.py @@ -0,0 +1,226 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from collections import Collection +from dataclasses import dataclass, field +from typing import List + +import torch +import torch.distributed as dist +import torch.optim +from fairseq.dataclass import FairseqDataclass +from fairseq.optim import FairseqOptimizer, register_optimizer +from fairseq.optim.fused_adam import get_fused_adam_class +from omegaconf import II, DictConfig + + +logger = logging.getLogger(__name__) + + +@dataclass +class FairseqAdamConfig(FairseqDataclass): + adam_betas: str = field( + default="(0.9, 0.999)", metadata={"help": "betas for Adam optimizer"} + ) + adam_eps: float = field( + default=1e-8, metadata={"help": "epsilon for Adam optimizer"} + ) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + use_old_adam: bool = field( + default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} + ) + # TODO common vars below in parent + tpu: bool = II("common.tpu") + lr: List[float] = II("optimization.lr") + + +@register_optimizer("adam", dataclass=FairseqAdamConfig) +class FairseqAdam(FairseqOptimizer): + """Adam optimizer for fairseq. + + Important note: this optimizer corresponds to the "AdamW" variant of + Adam in its weight decay behavior. As such, it is most closely + analogous to torch.optim.AdamW from PyTorch. + """ + + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) + fused_adam_cls = get_fused_adam_class() + use_fused_adam = ( + not getattr(cfg, "use_old_adam", False) + and fused_adam_cls is not None + and torch.cuda.is_available() + ) + if getattr(cfg, "tpu", False): + # on TPUs we use the Adam defined here, since it + # automatically casts gradients to FP32 + self._optimizer = Adam(params, **self.optimizer_config) + elif use_fused_adam: + logger.info("using FusedAdam") + self._optimizer = fused_adam_cls(params, **self.optimizer_config) + else: + self._optimizer = Adam(params, **self.optimizer_config) + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "betas": eval(self.cfg.adam_betas), + "eps": self.cfg.adam_eps, + "weight_decay": self.cfg.weight_decay, + } + + def average_params(self): + """Reduce Params is only used during BMUF distributed training.""" + state_dict = self.optimizer.state_dict() + total_gpus = float(dist.get_world_size()) + + for _, value in state_dict["state"].items(): + value["exp_avg"] /= total_gpus + value["exp_avg_sq"] /= total_gpus + dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM) + dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM) + + +class Adam(torch.optim.Optimizer): + """Implements Adam algorithm. + + This implementation is modified from torch.optim.Adam based on: + `Fixed Weight Decay Regularization in Adam` + (see https://arxiv.org/abs/1711.05101) + + It has been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + ): + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad + ) + super(Adam, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group.get("amsgrad", False) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p_data_fp32) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) + if amsgrad: + state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( + p_data_fp32 + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/fairseq/optim/adamax.py b/fairseq/optim/adamax.py new file mode 100644 index 0000000..577a688 --- /dev/null +++ b/fairseq/optim/adamax.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adamax") +class FairseqAdamax(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = Adamax(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--adamax-betas', default='(0.9, 0.999)', metavar='B', + help='betas for Adam optimizer') + parser.add_argument('--adamax-eps', type=float, default=1e-8, metavar='D', + help='epsilon for Adam optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--no-bias-correction', default=False, action='store_true', + help='disable bias correction') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "betas": eval(self.args.adamax_betas), + "eps": self.args.adamax_eps, + "weight_decay": self.args.weight_decay, + "bias_correction": not self.args.no_bias_correction, + } + + +class Adamax(torch.optim.Optimizer): + """Implements Adamax algorithm (a variant of Adam based on infinity norm). + + It has been proposed in `Adam: A Method for Stochastic Optimization`__. + + Compared to the version in PyTorch, this version implements a fix for weight decay. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + bias_correction (bool, optional): enable bias correction (default: True) + + __ https://arxiv.org/abs/1412.6980 + """ + + def __init__( + self, + params, + lr=2e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + bias_correction=True, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + bias_correction=bias_correction, + ) + super(Adamax, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("Adamax does not support sparse gradients") + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_inf"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_inf"] = state["exp_inf"].to(p_data_fp32) + + exp_avg, exp_inf = state["exp_avg"], state["exp_inf"] + beta1, beta2 = group["betas"] + eps = group["eps"] + + state["step"] += 1 + + # Update biased first moment estimate. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + # Update the exponentially weighted infinity norm. + torch.max( + exp_inf.mul_(beta2), + grad.abs_(), + out=exp_inf, + ) + + step_size = group["lr"] + if group["bias_correction"]: + bias_correction = 1 - beta1 ** state["step"] + step_size /= bias_correction + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + p_data_fp32.addcdiv_(exp_avg, exp_inf.add(eps), value=-step_size) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py new file mode 100644 index 0000000..d6d0e04 --- /dev/null +++ b/fairseq/optim/bmuf.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +import torch +import torch.distributed as dist +from fairseq.dataclass.configs import FairseqBMUFConfig +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim.fairseq_optimizer import FairseqOptimizer + + +class FairseqBMUF(FairseqOptimizer): + """ + Implements incremental block distributed data parallelism similar to + https://ieeexplore.ieee.org/document/7472805 + + Paper title: Scalable training of deep learning machines by incremental + block training with intra-block parallel optimization and blockwise + model-update filtering + """ + + def __init__(self, cfg: FairseqBMUFConfig, optimizer): + super().__init__(cfg) + self._optimizer = optimizer + self._num_updates = 0 + self.sync_iter = cfg.global_sync_iter + self.block_momentum = cfg.block_momentum + self.block_lr = cfg.block_lr + self._reset_local_data() + self.warmup_iteration = cfg.warmup_iterations + self.use_nbm = cfg.use_nbm + self.initial_state = self._optimizer.state_dict() + self.average_sync = self.cfg.average_sync + self.world_size = self.cfg.distributed_world_size + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + gen_parser_from_dataclass(parser, FairseqBMUFConfig()) + + @property + def optimizer(self): + return self._optimizer.optimizer + + @property + def optimizer_config(self): + return self._optimizer.optimizer_config + + def get_lr(self): + return self._optimizer.get_lr() + + def set_lr(self, lr): + self._optimizer.set_lr(lr) + + def state_dict(self): + return self._optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + self._optimizer.load_state_dict(state_dict, optimizer_overrides) + self.initial_state = self._optimizer.state_dict() + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + self._optimizer.multiply_grads(c) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) + + def average_params(self): + self._optimizer.average_params() + + def _block_sync(self): + if self.world_size <= 1: + return + # Update the global model using local models from all GPUs + # (Step-1) Calculate grad between previously synced model and + # currrent local model + if self.block_momentum != 0: + self._calc_grad() + + # (Step-2) Average gradient from all GPUs + self._avg_grad_from_all_gpus() + + # (Step-3) Calculate global momentum and update the global model + if self.block_momentum != 0: + self._update_global_model() + + # (Step-4) Average local optimizer params + if self.average_sync: + self.average_params() + + def _is_warmup_end(self): + # Check whether train iterations is equal to warmup iter + if self.get_num_updates() == self.warmup_iteration: + return True + return False + + def _is_bmuf_iter(self): + # Check whether train iterations is equal to bmuf sync iter + if (self.get_num_updates() > self.warmup_iteration) and ( + self.get_num_updates() % self.sync_iter == 0 + ): + return True + return False + + def _warmup_sync(self, root_rank=0): + if self.world_size <= 1: + return + # Broadcast the local model to all gpus + for param in self.params: + dist.broadcast(param.data, src=root_rank) + + # Update local optimizer state + if self.average_sync: + self._optimizer.average_params() + else: + self._optimizer.load_state_dict(self.initial_state) + + self._reset_local_data() + + def step(self, closure=None): + """Performs a single optimization step.""" + self._optimizer.step(closure) + self.set_num_updates(self.get_num_updates() + 1) + if self._is_warmup_end(): + self._warmup_sync() + elif self._is_bmuf_iter(): + self._block_sync() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self._optimizer.zero_grad() + + def get_num_updates(self): + """Get the number of parameters updates.""" + return self._num_updates + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self._num_updates = num_updates + + @torch.no_grad() + def _reset_local_data(self): + # (Step-0) Initialize global momentum parameters and store global copy on each gpu + self.global_params = [torch.zeros_like(p.data) for p in self.params] + self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params] + self.grads = [p.data.new_zeros(p.data.size()) for p in self.params] + + # saving the global model locally for calculating gradient during bmuf sync + for param, global_param in zip(self.params, self.global_params): + global_param.copy_(param.data) + + @torch.no_grad() + def _calc_grad(self): + # global_params is basically the global copy from the previously finished + # synchronisation. param.data is local parameter after block_sync_freq + # for the local gpu. so grad is difference between previously synced + # model and currrent local model. + for index, (param, global_param) in enumerate( + zip(self.params, self.global_params) + ): + self.grads[index] = global_param - param.data + + def _avg_grad_from_all_gpus(self): + for index, param in enumerate(self.params): + sync_para = param.data if self.block_momentum == 0 else self.grads[index] + sync_para /= float(dist.get_world_size()) + dist.all_reduce(sync_para, op=dist.ReduceOp.SUM) + + @torch.no_grad() + def _update_global_model(self): + for index, (param, global_param, smoothed_grad, grad) in enumerate( + zip( + self.params, + self.global_params, + self.smoothed_grads, + # all gpus would share the same value of smoothed_grad, since it is + # always computed on synchronized gradients. + self.grads, + ) + ): + # global_param is basically last syncrhornized parameter. though + # smoothed_grad is local, all processes will have same value of + # smoothed_grad and hence param is globally synchronized copy. + # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t) + smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad + param.data.copy_(global_param - smoothed_grad) + + # A Nesterov momentum here is to do a partial weight update before + # calculating the gradient + if self.use_nbm: + param.data.copy_(param.data - self.block_momentum * smoothed_grad) + + # backup for the next synchronization. + self.smoothed_grads[index] = smoothed_grad + global_param.copy_(param.data) diff --git a/fairseq/optim/dynamic_loss_scaler.py b/fairseq/optim/dynamic_loss_scaler.py new file mode 100644 index 0000000..c5da604 --- /dev/null +++ b/fairseq/optim/dynamic_loss_scaler.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +class DynamicLossScaler(object): + def __init__( + self, + init_scale=2.0 ** 15, + scale_factor=2.0, + scale_window=2000, + tolerance=0.05, + threshold=None, + min_loss_scale=1e-4, + ): + self.loss_scale = init_scale + self.scale_factor = scale_factor + self.scale_window = scale_window + self.tolerance = tolerance + self.threshold = threshold + self._iter = 0 + self._last_overflow_iter = -1 + self._last_rescale_iter = -1 + self._overflows_since_rescale = 0 + self.min_loss_scale = min_loss_scale + + def scale(self, outputs): + return self.loss_scale * outputs + + def update(self): + if (self._iter - self._last_overflow_iter) % self.scale_window == 0: + self.loss_scale *= self.scale_factor + self._last_rescale_iter = self._iter + self._iter += 1 + + def _decrease_loss_scale(self): + self.loss_scale /= self.scale_factor + if self.threshold is not None: + self.loss_scale = max(self.loss_scale, self.threshold) + + def check_overflow(self, grad_norm): + # detect inf and nan + if grad_norm == float("inf") or grad_norm != grad_norm: + # overflow has occured + prev_scale = self.loss_scale + iter_since_rescale = self._iter - self._last_rescale_iter + + self._last_overflow_iter = self._iter + self._overflows_since_rescale += 1 + pct_overflow = self._overflows_since_rescale / float(iter_since_rescale) + if pct_overflow >= self.tolerance: + self._decrease_loss_scale() + self._last_rescale_iter = self._iter + self._overflows_since_rescale = 0 + + if self.loss_scale <= self.min_loss_scale: + # Use FloatingPointError as an uncommon error that parent + # functions can safely catch to stop training. + self.loss_scale = prev_scale + raise FloatingPointError( + ( + "Minimum loss scale reached ({}). Your loss is probably exploding. " + "Try lowering the learning rate, using gradient clipping or " + "increasing the batch size." + ).format(self.min_loss_scale) + ) + + self._iter += 1 + raise OverflowError("setting loss scale to: " + str(self.loss_scale)) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py new file mode 100644 index 0000000..9c09383 --- /dev/null +++ b/fairseq/optim/fairseq_optimizer.py @@ -0,0 +1,150 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq import utils +from fairseq.dataclass.utils import gen_parser_from_dataclass + + +class FairseqOptimizer(object): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + @classmethod + def add_args(cls, parser): + """Add optimizer-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @property + def optimizer(self): + """Return a torch.optim.optimizer.Optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") + return self._optimizer + + @optimizer.setter + def optimizer(self, optimizer): + """Reset optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") + self._optimizer = optimizer + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + raise NotImplementedError + + @property + def params(self): + """Return an iterable of the parameters held by the optimizer.""" + for param_group in self.param_groups: + for p in param_group["params"]: + yield p + + @property + def param_groups(self): + return self.optimizer.param_groups + + def __getstate__(self): + return self._optimizer.__getstate__() + + def get_lr(self): + """Return the current learning rate.""" + return self.param_groups[0]["lr"] + + def set_lr(self, lr): + """Set the learning rate.""" + for param_group in self.param_groups: + param_group["lr"] = lr + + def state_dict(self): + """Return the optimizer's state dict.""" + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + self.optimizer.load_state_dict(state_dict) + + if optimizer_overrides is not None and len(optimizer_overrides) > 0: + # override learning rate, momentum, etc. with latest values + for group in self.param_groups: + group.update(optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves.""" + loss.backward() + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + for p in self.params: + if p.grad is not None: + p.grad.data.mul_(c) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) + + def step(self, closure=None, scale=1.0): + """Performs a single optimization step.""" + if self.supports_step_with_scale: + self.optimizer.step(closure, scale=scale) + else: + if scale != 1.0: + self.multiply_grads(1.0 / scale) + self.optimizer.step(closure) + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.params: + p.grad = None + self.optimizer.zero_grad() + + @property + def supports_memory_efficient_fp16(self): + if hasattr(self.optimizer, "supports_memory_efficient_fp16"): + return self.optimizer.supports_memory_efficient_fp16 + return False + + @property + def supports_step_with_scale(self): + if hasattr(self.optimizer, "supports_step_with_scale"): + return self.optimizer.supports_step_with_scale + return False + + @property + def supports_flat_params(self): + """ + Whether the optimizer supports collapsing of the model + parameters/gradients into a single contiguous Tensor. + """ + if hasattr(self.optimizer, "supports_flat_params"): + return self.optimizer.supports_flat_params + return False + + def average_params(self): + pass + + +class LegacyFairseqOptimizer(FairseqOptimizer): + def __init__(self, args): + self.args = args diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py new file mode 100644 index 0000000..b08a723 --- /dev/null +++ b/fairseq/optim/fp16_optimizer.py @@ -0,0 +1,496 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from itertools import chain + +import torch +from fairseq import optim +from omegaconf import DictConfig + +from .dynamic_loss_scaler import DynamicLossScaler + + +class _FP16OptimizerMixin(object): + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in mro(method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + @property + def has_flat_params(self): + return torch.is_tensor(self.fp32_params) or ( + isinstance(self.fp32_params, dict) + and all(torch.is_tensor(t) for t in self.fp32_params.values()) + ) + + @classmethod + def build_fp32_params(cls, args, params, flatten=True): + # create FP32 copy of parameters and grads + if flatten: + is_pipeline_parallel = getattr( + args, "pipeline_model_parallel", False + ) and getattr(args, "distributed_no_spawn", False) + total_param_size = sum(p.data.numel() for p in params) + devices = [torch.cuda.current_device()] + if is_pipeline_parallel: + devices = list(set(args.pipeline_devices)) + fp32_params = {} + for device in devices: + if is_pipeline_parallel: + device_param_size = sum( + p.data.numel() for p in params if p.device.index == device + ) + device_params = [p for p in params if p.device.index == device] + else: + device_param_size = total_param_size + device_params = params + fp32_params[device] = ( + device_params[0].new(0).float().new(device_param_size) + ) + offset = 0 + for p in device_params: + numel = p.data.numel() + fp32_params[device][offset : offset + numel].copy_(p.data.view(-1)) + offset += numel + fp32_params[device] = torch.nn.Parameter(fp32_params[device]) + fp32_params[device].grad = fp32_params[device].data.new( + device_param_size + ) + return fp32_params + else: + fp32_params = [] + for p in params: + p32 = torch.nn.Parameter(p.data.float()) + p32.grad = torch.zeros_like(p32.data) + fp32_params.append(p32) + return fp32_params + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.fp32_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() + self._needs_sync = True + + def _sync_fp16_grads_to_fp32(self): + if self._needs_sync: + # copy FP16 grads to FP32 + if self.has_flat_params: + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) + for p in self.fp16_params: + if p.requires_grad: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + grad_data = ( + p.grad.data + if p.grad is not None + else p.data.new_zeros(p.data.shape) + ) + numel = grad_data.numel() + self.fp32_params[device].grad.data[ + offset : offset + numel + ].copy_(grad_data.view(-1)) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + if p.grad is not None: + p32.grad.data.copy_(p.grad.data) + else: + p32.grad = torch.zeros_like(p.data, dtype=torch.float) + + self._needs_sync = False + + def _sync_fp32_params_to_fp16(self): + # copy FP32 params back into FP16 model + if self.has_flat_params: + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) + for p in self.fp16_params: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + numel = p.data.numel() + p.data.copy_( + self.fp32_params[device] + .data[offset : offset + numel] + .view_as(p.data) + ) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + p.data.copy_(p32.data) + + def _unscale_grads(self): + self._sync_fp16_grads_to_fp32() + if self._multiply_factor != 1.0: + self.fp32_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + """Multiplies grads by a constant ``c``.""" + self._multiply_factor *= c + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm and updates dynamic loss scaler.""" + self._sync_fp16_grads_to_fp32() + + grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) + + if self.scaler is not None: + if grad_norm > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm + + self.scaler.check_overflow(grad_norm) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None): + """Performs a single optimization step.""" + self._sync_fp16_grads_to_fp32() + + if getattr(self, "supports_step_with_scale", False): + self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) + else: + self._unscale_grads() + self.fp32_optimizer.step(closure) + + if self.scaler is not None: + self.scaler.update() + + self._sync_fp32_params_to_fp16() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.fp16_params: + p.grad = None + if self.has_flat_params: + if torch.is_tensor(self.fp32_params): + self.fp32_params.grad.zero_() + elif isinstance(self.fp32_params, dict): + for fp32_params in self.fp32_params.values(): + fp32_params.grad.zero_() + else: + raise RuntimeError("self.fp32_params must be a tensor or dict") + else: + for p32 in self.fp32_params: + p32.grad.zero_() + self._needs_sync = False + + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + + +class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. + """ + + def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs): + super().__init__(cfg.optimizer) + self.fp16_params = params + self.fp32_optimizer = fp32_optimizer + self.fp32_params = fp32_params + + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int( + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = int( + 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] + ) + else: + scale_window = cfg.common.fp16_scale_window + + if not getattr(cfg.common, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=cfg.common.fp16_init_scale, + scale_window=scale_window, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, + ) + else: + # disable loss scaling for bfloat16 + self.scaler = None + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + cfg (omegaconf.DictConfig): fairseq args + params (iterable): iterable of parameters to optimize + """ + flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False) + if getattr(cfg.common, "bf16", False): + flatten = False # mixed precision is faster on TPUs without flat grads + fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten) + if flatten: + fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params]) + else: + fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params) + if flatten and not fp32_optimizer.supports_flat_params: + raise RuntimeError( + f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads" + ) + return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs) + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + +class _MemoryEfficientFP16OptimizerMixin(object): + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in MRO (method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + @property + def has_flat_params(self): + return False + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.wrapped_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + + self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides) + + # Hack: PyTorch automatically casts the optimizer state to match the + # type of the current parameters. But with --memory-efficient-fp16 the + # params are FP16 while the optimizer state is FP32 and we don't want + # to cast. A workaround is to manually copy back the original state + # after the optimizer has been loaded. + if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False): + groups = self.optimizer.param_groups + saved_groups = state_dict["param_groups"] + id_map = { + old_id: p + for old_id, p in zip( + chain(*(g["params"] for g in saved_groups)), + chain(*(g["params"] for g in groups)), + ) + } + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + self.optimizer.state[param] = v + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() + + def _unscale_grads(self): + if self._multiply_factor != 1.0: + self.wrapped_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + self._multiply_factor *= c + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm and updates dynamic loss scaler.""" + max_norm = float(max_norm) + grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) + + if self.scaler is not None: + grad_norm_cpu = float(grad_norm) + if grad_norm_cpu > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm_cpu + + # detect overflow and adjust loss scale + self.scaler.check_overflow(grad_norm_cpu) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None): + """Performs a single optimization step.""" + if getattr(self, "supports_step_with_scale", False): + # NOTE(msb) optimizer divides by scale factor + self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) + else: + self._unscale_grads() + self.wrapped_optimizer.step(closure) + + if self.scaler is not None: + self.scaler.update() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self.wrapped_optimizer.zero_grad() + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + else: + self._multiply_factor = 1.0 + + +class MemoryEfficientFP16Optimizer( + _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer +): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. + + Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not + maintain an FP32 copy of the model. We instead expect the optimizer to + convert the gradients to FP32 internally and sync the results back to the + FP16 model params. This significantly reduces memory usage but slightly + increases the time spent in the optimizer. + + Since this wrapper depends on specific functionality in the wrapped + optimizer (i.e., on-the-fly conversion of grads to FP32), only certain + optimizers can be wrapped. This is determined by the + *supports_memory_efficient_fp16* property. + """ + + def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): + if not optimizer.supports_memory_efficient_fp16: + raise ValueError( + "Unsupported optimizer: {}".format(optimizer.__class__.__name__) + ) + + super().__init__(cfg.optimizer) + self.wrapped_optimizer = optimizer + + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int( + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = ( + 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] + ) + else: + scale_window = cfg.common.fp16_scale_window + + if not getattr(cfg.common, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=cfg.common.fp16_init_scale, + scale_window=scale_window, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, + ) + else: + # disable loss scaling for bfloat16 + self.scaler = None + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + args (argparse.Namespace): fairseq args + params (iterable): iterable of parameters to optimize + """ + fp16_optimizer = optim.build_optimizer(cfg.optimizer, params) + return cls(cfg, params, fp16_optimizer, **kwargs) + + @property + def optimizer(self): + return self.wrapped_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.wrapped_optimizer.optimizer = optimizer + + @property + def optimizer_config(self): + return self.wrapped_optimizer.optimizer_config + + def get_lr(self): + return self.wrapped_optimizer.get_lr() + + def set_lr(self, lr): + self.wrapped_optimizer.set_lr(lr) diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py new file mode 100644 index 0000000..1780f9c --- /dev/null +++ b/fairseq/optim/fused_adam.py @@ -0,0 +1,348 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import types + +import torch + + +def get_fused_adam_class(): + """ + Look for the FusedAdam optimizer from apex. We first try to load the + "contrib" interface, which is a bit faster than the main interface, + but is technically deprecated. + """ + try: + # The "deprecated" interface in recent versions of apex is a bit + # faster than the main interface, since we don't use the apex + # optimizer. This can be installed by passing the + # `--deprecated_fused_adam` option when building apex. + global fused_adam_cuda + import importlib + + fused_adam_cuda = importlib.import_module("fused_adam_cuda") + return FusedAdamV1 + except ImportError: + try: + # fallback to the newer interface + from apex.optimizers import FusedAdam as _FusedAdam # noqa + from apex.multi_tensor_apply import multi_tensor_applier + + if multi_tensor_applier.available: + return FusedAdamV2 + except ImportError: + pass + return None + + +class FusedAdamV1(torch.optim.Optimizer): + """ + Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via + ``python setup.py install --cuda_ext --cpp_ext``. + + It has been proposed in `Adam: A Method for Stochastic Optimization`_. + + Compared to the original version in Apex, the fairseq version casts grads + and params to FP32 internally to support ``--memory-efficient-fp16``. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in FusedAdam! + eps_inside_sqrt (boolean, optional): in the 'update parameters' step, + adds eps to the bias-corrected second moment estimate before + evaluating square root instead of adding it to the square root of + second moment estimate as in the original paper. (default: False) + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + eps_inside_sqrt=False, + weight_decay=0.0, + max_grad_norm=0.0, + amsgrad=False, + ): + global fused_adam_cuda + import importlib + + fused_adam_cuda = importlib.import_module("fused_adam_cuda") + + if amsgrad: + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") + defaults = { + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "max_grad_norm": max_grad_norm, + } + super().__init__(params, defaults) + self.eps_mode = 0 if eps_inside_sqrt else 1 + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + @property + def supports_step_with_scale(self): + return True + + def step(self, closure=None, grads=None, scale=1.0, grad_norms=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + grads (list of tensors, optional): weight gradient to use for the + optimizer update. If gradients have type torch.half, parameters + are expected to be in type torch.float. (default: None) + output params (list of tensors, optional): A reduced precision copy + of the updated weights written out in addition to the regular + updated weights. Have to be of same type as gradients. (default: None) + scale (float, optional): factor to divide gradient tensor values + by before applying to weights. (default: 1) + """ + loss = None + if closure is not None: + loss = closure() + + if grads is None: + grads_group = [None] * len(self.param_groups) + # backward compatibility + # assuming a list/generator of parameter means single group + elif isinstance(grads, types.GeneratorType): + grads_group = [grads] + elif type(grads[0]) != list: + grads_group = [grads] + else: + grads_group = grads + + if grad_norms is None: + grad_norms = [None] * len(self.param_groups) + + for group, grads_this_group, grad_norm in zip( + self.param_groups, grads_group, grad_norms + ): + if grads_this_group is None: + grads_this_group = [None] * len(group["params"]) + + # compute combined scale factor for this group + combined_scale = scale + if group.get("max_grad_norm", 0) > 0: + # norm is in fact norm*scale + clip = ((grad_norm / scale) + 1e-6) / group["max_grad_norm"] + if clip > 1: + combined_scale = clip * scale + + bias_correction = 1 if group.get("bias_correction", 1) else 0 + + for p, grad in zip(group["params"], grads_this_group): + # note: p.grad should not ever be set for correct + # operation of mixed precision optimizer that sometimes + # sends None gradients + if p.grad is None and grad is None: + continue + if grad is None: + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "FusedAdam does not support sparse gradients, " + "please consider SparseAdam instead" + ) + + p_data_fp32 = p.data.float() + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p_data_fp32) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + out_p = p.data + with torch.cuda.device(p.device): + fused_adam_cuda.adam( + p_data_fp32, + out_p, + exp_avg, + exp_avg_sq, + grad, + group["lr"], + beta1, + beta2, + group["eps"], + combined_scale, + state["step"], + self.eps_mode, + bias_correction, + group["weight_decay"], + ) + + return loss + + +try: + from apex.optimizers import FusedAdam + from apex.multi_tensor_apply import multi_tensor_applier + + class FusedAdamV2(FusedAdam): + """ + Compared to the original version in Apex, the fairseq version casts grads + and params to FP32 internally to support ``--memory-efficient-fp16``. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not hasattr(self, "multi_tensor_adam"): + raise Exception( + "Apex installation is outdated. Please install an updated version of apex." + ) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step( + self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None, + ): + """Performs a single optimization step.""" + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if "step" in group: + group["step"] += 1 + else: + group["step"] = 1 + + # create lists for multi-tensor apply + g_16, p_16, orig_p_16, m_16, v_16 = [], [], [], [], [] + g_32, p_32, m_32, v_32 = [], [], [], [] + + for p in group["params"]: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + "FusedAdam does not support sparse gradients, " + "please consider SparseAdam instead" + ) + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p.data, dtype=torch.float + ) + else: + state["exp_avg"] = state["exp_avg"].to( + device=p.data.device, dtype=torch.float + ) + state["exp_avg_sq"] = state["exp_avg_sq"].to( + device=p.data.device, dtype=torch.float + ) + + if p.dtype == torch.float16: + g_16.append(p.grad.data.float()) + p_16.append(p.data.float()) + orig_p_16.append(p.data) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) + elif p.dtype == torch.float32: + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) + else: + raise RuntimeError("FusedAdam only support fp16 and fp32.") + + with torch.cuda.device(p.device): + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + for orig_p, p in zip(orig_p_16, p_16): + orig_p.copy_(p.data) + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + return loss + + +except ImportError: + pass diff --git a/fairseq/optim/fused_lamb.py b/fairseq/optim/fused_lamb.py new file mode 100644 index 0000000..f4f2bdb --- /dev/null +++ b/fairseq/optim/fused_lamb.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.optim import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("lamb") +class FairseqLAMB(LegacyFairseqOptimizer): + """LAMB optimizer.""" + + def __init__(self, args, params): + super().__init__(args) + try: + from apex.optimizers import FusedLAMB + + self._optimizer = FusedLAMB(params, **self.optimizer_config) + except ImportError: + raise ImportError("Please install apex to use LAMB optimizer") + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B', + help='betas for LAMB optimizer') + parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D', + help='epsilon for LAMB optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "betas": eval(self.args.lamb_betas), + "eps": self.args.lamb_eps, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return False diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py new file mode 100644 index 0000000..f07d43c --- /dev/null +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from fairseq import registry +from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa + FairseqLRScheduler, + LegacyFairseqLRScheduler, +) +from omegaconf import DictConfig + + +( + build_lr_scheduler_, + register_lr_scheduler, + LR_SCHEDULER_REGISTRY, + LR_SCHEDULER_DATACLASS_REGISTRY, +) = registry.setup_registry( + "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed" +) + + +def build_lr_scheduler(cfg: DictConfig, optimizer): + return build_lr_scheduler_(cfg, optimizer) + + +# automatically import any Python files in the optim/lr_scheduler/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.optim.lr_scheduler." + file_name) diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py new file mode 100644 index 0000000..c3c6663 --- /dev/null +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -0,0 +1,154 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections import Collection +from dataclasses import dataclass, field +from typing import List + +from fairseq.dataclass import FairseqDataclass +from omegaconf import II, DictConfig + +from . import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class CosineConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is args.lr" + }, + ) + max_lr: float = field( + default=1.0, metadata={"help": "max learning rate, must be more than args.lr"} + ) + t_mult: float = field( + default=1.0, metadata={"help": "factor to grow the length of each period"} + ) + lr_period_updates: float = field( + default=-1, metadata={"help": "initial number of updates per period"} + ) + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + # TODO common var for parent class + lr: List[float] = II("optimization.lr") + max_update: int = II("optimization.max_update") + + +@register_lr_scheduler("cosine", dataclass=CosineConfig) +class CosineSchedule(FairseqLRScheduler): + """Assign LR based on a cyclical schedule that follows the cosine function. + + See https://arxiv.org/pdf/1608.03983.pdf for details. + + We also support a warmup phase where we linearly increase the learning rate + from some initial learning rate (``--warmup-init-lr``) until the configured + max learning rate (``--max-lr``). + + During warmup:: + + lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) + lr = lrs[update_num] + + After warmup:: + + lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) + + where ``t_curr`` is current percentage of updates within the current period + range and ``t_i`` is the current period range, which is scaled by ``t_mul`` + after every iteration. + """ + + def __init__( + self, cfg: DictConfig, fairseq_optimizer + ): + super().__init__(cfg, fairseq_optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with cosine." + " Consider --lr-scheduler=fixed instead." + ) + + warmup_end_lr = cfg.max_lr + lr = ( + cfg.lr[0] + if isinstance(cfg.lr, Collection) + else cfg.lr + ) + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = lr + + self.min_lr = lr + self.max_lr = cfg.max_lr + assert self.max_lr > self.min_lr, "max_lr must be more than lr" + + self.t_mult = cfg.t_mult + self.period = cfg.lr_period_updates + + if self.period <= 0: + assert ( + cfg.max_update >= 0 + ), "Either --max_update or --lr-period-updates must be set" + self.period = cfg.max_update - cfg.warmup_updates + + if cfg.warmup_updates > 0: + # linearly warmup for the first args.warmup_updates + self.lr_step = ( + warmup_end_lr - cfg.warmup_init_lr + ) / cfg.warmup_updates + else: + self.lr_step = 1 + + self.warmup_updates = cfg.warmup_updates + self.lr_shrink = cfg.lr_shrink + + # initial learning rate + self.lr = cfg.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step + else: + curr_updates = num_updates - self.cfg.warmup_updates + if self.t_mult != 1: + i = math.floor( + math.log( + 1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult + ) + ) + t_i = self.t_mult ** i * self.period + t_curr = ( + curr_updates + - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period + ) + else: + i = math.floor(curr_updates / self.period) + t_i = self.period + t_curr = curr_updates - (self.period * i) + + lr_shrink = self.lr_shrink ** i + min_lr = self.min_lr * lr_shrink + max_lr = self.max_lr * lr_shrink + + self.lr = min_lr + 0.5 * (max_lr - min_lr) * ( + 1 + math.cos(math.pi * t_curr / t_i) + ) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py new file mode 100644 index 0000000..569e448 --- /dev/null +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace + +from fairseq.dataclass.utils import gen_parser_from_dataclass + +from .. import FairseqOptimizer + + +class FairseqLRScheduler(object): + def __init__(self, cfg, optimizer): + super().__init__() + if not isinstance(optimizer, FairseqOptimizer): + raise ValueError("optimizer must be an instance of FairseqOptimizer") + self.cfg = cfg + self.optimizer = optimizer + self.best = None + + @classmethod + def add_args(cls, parser): + """Add arguments to the parser for this LR scheduler.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {"best": self.best} + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + self.best = state_dict["best"] + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + if val_loss is not None: + if self.best is None: + self.best = val_loss + else: + self.best = min(self.best, val_loss) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.get_lr() + + +class LegacyFairseqLRScheduler(FairseqLRScheduler): + def __init__(self, args: Namespace, optimizer): + if not isinstance(optimizer, FairseqOptimizer): + raise ValueError("optimizer must be an instance of FairseqOptimizer") + self.args = args + self.optimizer = optimizer + self.best = None diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py new file mode 100644 index 0000000..7ca7826 --- /dev/null +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import LegacyFairseqLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("fixed") +class FixedSchedule(LegacyFairseqLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + + # set defaults + args.warmup_updates = getattr(args, "warmup_updates", 0) or 0 + + self.lr = args.lr[0] + if args.warmup_updates > 0: + self.warmup_factor = 1.0 / args.warmup_updates + else: + self.warmup_factor = 1 + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', + help='force annealing at specified epoch') + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing, lr_new = (lr * lr_shrink)') + parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') + # fmt: on + + def state_dict(self): + return {"lr": self.lr} + + def load_state_dict(self, state_dict): + if "lr" in state_dict: + self.lr = state_dict["lr"] + + def get_next_lr(self, epoch): + lrs = self.args.lr + if self.args.force_anneal is None or epoch < self.args.force_anneal: + # use fixed LR schedule + next_lr = lrs[min(epoch, len(lrs) - 1)] + else: + # annneal based on lr_shrink + next_lr = lrs[-1] * self.args.lr_shrink ** ( + epoch + 1 - self.args.force_anneal + ) + return next_lr + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.warmup_factor * self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates: + self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates) + self.optimizer.set_lr(self.warmup_factor * self.lr) + else: + self.optimizer.set_lr(self.lr) + return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py new file mode 100644 index 0000000..c42e090 --- /dev/null +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import Collection +from dataclasses import dataclass, field +from typing import List + +from fairseq.dataclass import FairseqDataclass +from omegaconf import II, DictConfig + +from . import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class InverseSquareRootScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=4000, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is args.lr" + }, + ) + # TODO common vars at parent class + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig) +class InverseSquareRootSchedule(FairseqLRScheduler): + """Decay the LR based on the inverse square root of the update number. + + We also support a warmup phase where we linearly increase the learning rate + from some initial learning rate (``--warmup-init-lr``) until the configured + learning rate (``--lr``). Thereafter we decay proportional to the number of + updates, with a decay factor set to align with the configured learning rate. + + During warmup:: + + lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) + lr = lrs[update_num] + + After warmup:: + + decay_factor = args.lr * sqrt(args.warmup_updates) + lr = decay_factor / sqrt(update_num) + """ + + def __init__(self, cfg: DictConfig, optimizer): + super().__init__(cfg, optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with inverse_sqrt." + " Consider --lr-scheduler=fixed instead." + ) + warmup_end_lr = ( + cfg.lr[0] + if isinstance(cfg.lr, Collection) + else cfg.lr + ) + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = ( + 0 if cfg.warmup_updates > 0 else warmup_end_lr + ) + + # linearly warmup for the first args.warmup_updates + self.lr_step = ( + warmup_end_lr - cfg.warmup_init_lr + ) / cfg.warmup_updates + + # then, decay prop. to the inverse square root of the update number + self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5 + + # initial learning rate + self.lr = cfg.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step + else: + self.lr = self.decay_factor * num_updates ** -0.5 + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py new file mode 100644 index 0000000..ea8e647 --- /dev/null +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import LegacyFairseqLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("polynomial_decay") +class PolynomialDecaySchedule(LegacyFairseqLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + + # set defaults + args.warmup_updates = getattr(args, "warmup_updates", 0) or 0 + + self.lr = args.lr[0] + if args.warmup_updates > 0: + self.warmup_factor = 1.0 / args.warmup_updates + else: + self.warmup_factor = 1 + self.end_learning_rate = args.end_learning_rate + self.total_num_update = args.total_num_update + self.power = args.power + self.optimizer.set_lr(self.warmup_factor * self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + parser.add_argument( + "--force-anneal", + "--fa", + type=int, + metavar="N", + help="force annealing at specified epoch", + ) + parser.add_argument( + "--warmup-updates", + default=0, + type=int, + metavar="N", + help="warmup the learning rate linearly for the first N updates", + ) + parser.add_argument("--end-learning-rate", default=0.0, type=float) + parser.add_argument("--power", default=1.0, type=float) + parser.add_argument("--total-num-update", default=1000000, type=int) + + def get_next_lr(self, epoch): + lrs = self.args.lr + if self.args.force_anneal is None or epoch < self.args.force_anneal: + # use fixed LR schedule + next_lr = lrs[min(epoch, len(lrs) - 1)] + else: + # annneal based on lr_shrink + next_lr = self.optimizer.get_lr() + return next_lr + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.warmup_factor * self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: + self.warmup_factor = num_updates / float(self.args.warmup_updates) + lr = self.warmup_factor * self.lr + elif num_updates >= self.total_num_update: + lr = self.end_learning_rate + else: + warmup = self.args.warmup_updates + lr_range = self.lr - self.end_learning_rate + pct_remaining = 1 - (num_updates - warmup) / ( + self.total_num_update - warmup + ) + lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate + self.optimizer.set_lr(lr) + return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py new file mode 100644 index 0000000..82bb36e --- /dev/null +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.optim.lr_scheduler + +from . import LegacyFairseqLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("reduce_lr_on_plateau") +class ReduceLROnPlateau(LegacyFairseqLRScheduler): + """ + Decay the LR by a factor every time the validation loss plateaus. + Also comes with optional warmup phase, where we linearly increase + the learning rate from some initial learning rate + (``--warmup-init-lr``) until the configured learning rate + (``--lr``). Thereafter the lr is adjusted according to original + reduce_on_plateau scheme. + + During warmup:: + + lrs = torch.linspace( + args.warmup_init_lr, args.lr, args.warmup_updates + ) + lr = lrs[update_num] + """ + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + if len(args.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau." + " Consider --lr-scheduler=fixed instead." + ) + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer.optimizer, + patience=args.lr_patience, + factor=args.lr_shrink, + mode="max" if args.maximize_best_checkpoint_metric else "min", + threshold=args.lr_threshold, + ) + warmup_end_lr = args.lr[0] + # if no warm up, sets initial lr to be args.lr[0] + if args.warmup_init_lr < 0: + args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr + + # linearly warmup for the first args.warmup_updates + if args.warmup_updates > 0: + self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + + # this flag is either set from arg when no warm up, or set by + # step_update() when warmup finishes + self.warmup_end = True if args.warmup_updates <= 0 else False + + # initial learning rate + # this self.lr is used only during init and/or warm up period + self.lr = args.warmup_init_lr + self.optimizer.set_lr(self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing, lr_new = (lr * lr_shrink)') + parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT', + help='threshold for measuring the new optimum, ' + 'to only focus on significant changes') + parser.add_argument('--lr-patience', default=0, type=int, + help='number of epochs with no improvement after which ' + 'learning rate will be reduced') + parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') + parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', + help='initial learning rate during warmup phase; default is args.lr') + # fmt: on + + def state_dict(self): + """Return the LR scheduler state dict.""" + return { + "best": self.lr_scheduler.best, + "last_epoch": self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + self.lr_scheduler.best = state_dict["best"] + if "last_epoch" in state_dict: + self.lr_scheduler.last_epoch = state_dict["last_epoch"] + + def step(self, epoch, val_loss=None): + """ + Update the learning rate at the end of the given epoch if warmup + finishes otherwise no update of lr on epoch boundaries + """ + if val_loss is not None and self.warmup_end is True: + self.lr_scheduler.step(val_loss) + else: + self.lr_scheduler.last_epoch = epoch + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """ + Update the learning rate after each update.""" + # if there is warmup + if self.args.warmup_updates > 0: + if num_updates <= self.args.warmup_updates: + self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + self.optimizer.set_lr(self.lr) + else: + if self.warmup_end is False: + self.warmup_end = True + # else do nothing + return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py new file mode 100644 index 0000000..c573237 --- /dev/null +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -0,0 +1,165 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from . import LegacyFairseqLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("tri_stage") +class TriStageLRSchedule(LegacyFairseqLRScheduler): + """Tristage learning rate schedulr + + Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf + + Similar to inverse_squre_root scheduler, but tri_stage learning rate employs + three stages LR scheduling: + + - warmup stage, starting from `lr` * `init_lr_scale`, linearly + increased to `lr` in `warmup_steps` iterations + + - hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps` + iterations + + - decay stage, after hold stage, decay LR exponetially to + `lr` * `final_lr_scale` in `decay_steps`; + after that LR is keep as `final_lr_scale` * `lr` + + During warmup:: + + init_lr = args.init_lr_scale * args.lr + lrs = torch.linspace(init_lr, args.lr, args.warmup_steps) + lr = lrs[update_num] + + During hold:: + + lr = args.lr + + During decay:: + + decay_factor = - math.log(args.final_lr_scale) / args.decay_steps + lr = args.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor) + + After that:: + + lr = args.lr * args.final_lr_scale + """ + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + if len(args.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with tri-stage lr." + " Consider --lr-scheduler=fixed instead." + ) + + # calculate LR at each point + self.peak_lr = args.lr[0] + self.init_lr = args.init_lr_scale * args.lr[0] + self.final_lr = args.final_lr_scale * args.lr[0] + + # remember the steps at each stage + self.warmup_steps = args.warmup_steps + self.hold_steps = args.hold_steps + self.decay_steps = args.decay_steps + + self.warmup_rate = ( + (self.peak_lr - self.init_lr) / self.warmup_steps + if self.warmup_steps != 0 + else 0 + ) + self.decay_factor = -math.log(args.final_lr_scale) / args.decay_steps + + # initial learning rate + self.lr = self.init_lr + self.optimizer.set_lr(self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument( + '--warmup-steps', + default=4000, + type=int, + metavar='N', + help='warmup the learning rate linearly for the first N updates' + ) + parser.add_argument( + '--hold-steps', + default=20000, + type=int, + metavar='N', + help='steps in hold stage.' + ) + parser.add_argument( + '--decay-steps', + default=60000, + type=int, + metavar='N', + help='steps in decay stages' + ) + parser.add_argument( + '--init-lr-scale', + default=0.01, + type=float, + help=""" + initial learning rate scale during warmup phase; default is 0.01""") + parser.add_argument( + '--final-lr-scale', + default=0.01, + type=float, + help="final learning rate scale; default to 0.01" + ) + # fmt: on + + def _decide_stage(self, update_step): + """ + return stage, and the corresponding steps within the current stage + """ + if update_step < self.warmup_steps: + # warmup state + return 0, update_step + + offset = self.warmup_steps + + if update_step < offset + self.hold_steps: + # hold stage + return 1, update_step - offset + + offset += self.hold_steps + + if update_step <= offset + self.decay_steps: + # decay stage + return 2, update_step - offset + + offset += self.decay_steps + + # still here ? constant lr stage + return 3, update_step - offset + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + stage, steps_in_stage = self._decide_stage(num_updates) + if stage == 0: + self.lr = self.init_lr + self.warmup_rate * steps_in_stage + elif stage == 1: + self.lr = self.peak_lr + elif stage == 2: + self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage) + elif stage == 3: + self.lr = self.final_lr + else: + raise ValueError("Undefined stage") + + self.optimizer.set_lr(self.lr) + + return self.lr diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py new file mode 100644 index 0000000..0f3193f --- /dev/null +++ b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -0,0 +1,74 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from . import LegacyFairseqLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("triangular") +class TriangularSchedule(LegacyFairseqLRScheduler): + """Assign LR based on a triangular cyclical schedule. + + See https://arxiv.org/pdf/1506.01186.pdf for details. + """ + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + if len(args.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with triangular." + " Consider --lr-scheduler=fixed instead." + ) + + lr = args.lr[0] + + assert args.max_lr > lr, "max_lr must be more than lr" + self.min_lr = lr + self.max_lr = args.max_lr + self.stepsize = args.lr_period_updates // 2 + self.lr_shrink = args.lr_shrink + self.shrink_min = args.shrink_min + + # initial learning rate + self.lr = self.min_lr + self.optimizer.set_lr(self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--max-lr', required=True, type=float, metavar='LR', + help='max learning rate, must be more than args.lr') + parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', + help='initial number of updates per period (cycle length)') + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing') + parser.add_argument('--shrink-min', action='store_true', + help='if set, also shrinks min lr') + # fmt: on + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + cycle = math.floor(num_updates / (2 * self.stepsize)) + + lr_shrink = self.lr_shrink ** cycle + max_lr = self.max_lr * lr_shrink + if self.shrink_min: + min_lr = self.min_lr * lr_shrink + else: + min_lr = self.min_lr + + x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1) + self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x)) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py new file mode 100644 index 0000000..3982a82 --- /dev/null +++ b/fairseq/optim/nag.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import Collection +from dataclasses import dataclass, field +from typing import List + +import torch +from fairseq.dataclass import FairseqDataclass +from omegaconf import II, DictConfig +from torch.optim.optimizer import Optimizer, required + +from . import FairseqOptimizer, register_optimizer + + +@dataclass +class FairseqNAGConfig(FairseqDataclass): + momentum: float = field(default=0.99, metadata={"help": "momentum factor"}) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + # TODO common vars in parent class + lr: List[float] = II("optimization.lr") + + +@register_optimizer("nag", dataclass=FairseqNAGConfig) +class FairseqNAG(FairseqOptimizer): + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) + self._optimizer = NAG(params, **self.optimizer_config) + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "momentum": self.cfg.momentum, + "weight_decay": self.cfg.weight_decay, + } + + +class NAG(Optimizer): + def __init__(self, params, lr=required, momentum=0, weight_decay=0): + defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) + super(NAG, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group["weight_decay"] + momentum = group["momentum"] + lr = group["lr"] + lr_old = group.get("lr_old", lr) + lr_correct = lr / lr_old + + for p in group["params"]: + if p.grad is None: + continue + + p_data_fp32 = p.data + if p_data_fp32.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + d_p = p.grad.data.float() + param_state = self.state[p] + if "momentum_buffer" not in param_state: + param_state["momentum_buffer"] = torch.zeros_like(d_p) + else: + param_state["momentum_buffer"] = param_state["momentum_buffer"].to( + d_p + ) + + buf = param_state["momentum_buffer"] + + if weight_decay != 0: + p_data_fp32.mul_(1 - lr * weight_decay) + p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct) + p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr) + + buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + group["lr_old"] = lr + + return loss diff --git a/fairseq/optim/sgd.py b/fairseq/optim/sgd.py new file mode 100644 index 0000000..8e34fb9 --- /dev/null +++ b/fairseq/optim/sgd.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("sgd") +class SGD(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = torch.optim.SGD(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--momentum', default=0.0, type=float, metavar='M', + help='momentum factor') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "momentum": self.args.momentum, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return True diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py new file mode 100644 index 0000000..ecef05b --- /dev/null +++ b/fairseq/optim/shard.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +try: + from fairscale.optim import OSS + + _has_fairscale = True +except ImportError: + _has_fairscale = False + + +def shard_(optimizer, group): + if not _has_fairscale: + raise ImportError( + "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" + ) + + class FairseqOSS(OSS): + @property + def disable_mem_eff_fp16_loading_hack(self): + return True + + def __getattr__(self, name): + if name.startswith("supports") and hasattr(self.optim, name): + return getattr(self.optim, name) + raise AttributeError( + "'FairseqOSS' object has no attribute {0!r}".format(name) + ) + + torch_optimizer = optimizer.optimizer + optim_cls = type(torch_optimizer) + + optimizer.optimizer = FairseqOSS( + torch_optimizer.param_groups, + optim_cls, + group=group, + **optimizer.optimizer_config + ) diff --git a/fairseq/options.py b/fairseq/options.py new file mode 100644 index 0000000..b79443a --- /dev/null +++ b/fairseq/options.py @@ -0,0 +1,363 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from typing import Callable, List, Optional + +import torch +from fairseq import utils +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass.configs import ( + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + EvalLMConfig, + GenerationConfig, + InteractiveConfig, + OptimizationConfig, +) +from fairseq.dataclass.utils import gen_parser_from_dataclass + +# this import is for backward compatibility +from fairseq.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list # noqa + + +def get_preprocessing_parser(default_task="translation"): + parser = get_parser("Preprocessing", default_task) + add_preprocess_args(parser) + return parser + + +def get_training_parser(default_task="translation"): + parser = get_parser("Trainer", default_task) + add_dataset_args(parser, train=True) + add_distributed_training_args(parser) + add_model_args(parser) + add_optimization_args(parser) + add_checkpoint_args(parser) + return parser + + +def get_generation_parser(interactive=False, default_task="translation"): + parser = get_parser("Generation", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_generation_args(parser) + add_checkpoint_args(parser) + if interactive: + add_interactive_args(parser) + return parser + + +def get_interactive_generation_parser(default_task="translation"): + return get_generation_parser(interactive=True, default_task=default_task) + + +def get_eval_lm_parser(default_task="language_modeling"): + parser = get_parser("Evaluate Language Model", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_eval_lm_args(parser) + return parser + + +def get_validation_parser(default_task=None): + parser = get_parser("Validation", default_task) + add_dataset_args(parser, train=True) + add_distributed_training_args(parser, default_world_size=1) + group = parser.add_argument_group("Evaluation") + gen_parser_from_dataclass(group, CommonEvalConfig()) + return parser + + +def parse_args_and_arch( + parser: argparse.ArgumentParser, + input_args: List[str] = None, + parse_known: bool = False, + suppress_defaults: bool = False, + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None, +): + """ + Args: + parser (ArgumentParser): the parser + input_args (List[str]): strings to parse, defaults to sys.argv + parse_known (bool): only parse known arguments, similar to + `ArgumentParser.parse_known_args` + suppress_defaults (bool): parse while ignoring all default values + modify_parser (Optional[Callable[[ArgumentParser], None]]): + function to modify the parser, e.g., to set default values + """ + if suppress_defaults: + # Parse args without any default values. This requires us to parse + # twice, once to identify all the necessary task/model args, and a second + # time with all defaults set to None. + args = parse_args_and_arch( + parser, + input_args=input_args, + parse_known=parse_known, + suppress_defaults=False, + ) + suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser]) + suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()}) + args = suppressed_parser.parse_args(input_args) + return argparse.Namespace( + **{k: v for k, v in vars(args).items() if v is not None} + ) + + from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY + + # Before creating the true parser, we need to import optional user module + # in order to eagerly import custom tasks, optimizers, architectures, etc. + usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + usr_parser.add_argument("--user-dir", default=None) + usr_args, _ = usr_parser.parse_known_args(input_args) + utils.import_user_module(usr_args) + + if modify_parser is not None: + modify_parser(parser) + + # The parser doesn't know about model/criterion/optimizer-specific args, so + # we parse twice. First we parse the model/criterion/optimizer, then we + # parse a second time after adding the *-specific arguments. + # If input_args is given, we will parse those args instead of sys.argv. + args, _ = parser.parse_known_args(input_args) + + # Add model-specific args to parser. + if hasattr(args, "arch"): + model_specific_group = parser.add_argument_group( + "Model-specific configuration", + # Only include attributes which are explicitly given as command-line + # arguments or which have default values. + argument_default=argparse.SUPPRESS, + ) + if args.arch in ARCH_MODEL_REGISTRY: + ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) + elif args.arch in MODEL_REGISTRY: + MODEL_REGISTRY[args.arch].add_args(model_specific_group) + else: + raise RuntimeError() + + if hasattr(args, "task"): + from fairseq.tasks import TASK_REGISTRY + + TASK_REGISTRY[args.task].add_args(parser) + if getattr(args, "use_bmuf", False): + # hack to support extra args for block distributed data parallelism + from fairseq.optim.bmuf import FairseqBMUF + + FairseqBMUF.add_args(parser) + + # Add *-specific args to parser. + from fairseq.registry import REGISTRIES + + for registry_name, REGISTRY in REGISTRIES.items(): + choice = getattr(args, registry_name, None) + if choice is not None: + cls = REGISTRY["registry"][choice] + if hasattr(cls, "add_args"): + cls.add_args(parser) + elif hasattr(cls, "__dataclass"): + gen_parser_from_dataclass(parser, cls.__dataclass()) + + # Modify the parser a second time, since defaults may have been reset + if modify_parser is not None: + modify_parser(parser) + + # Parse a second time. + if parse_known: + args, extra = parser.parse_known_args(input_args) + else: + args = parser.parse_args(input_args) + extra = None + # Post-process args. + if ( + hasattr(args, "batch_size_valid") and args.batch_size_valid is None + ) or not hasattr(args, "batch_size_valid"): + args.batch_size_valid = args.batch_size + if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None: + args.max_tokens_valid = args.max_tokens + if getattr(args, "memory_efficient_fp16", False): + args.fp16 = True + if getattr(args, "memory_efficient_bf16", False): + args.bf16 = True + args.tpu = getattr(args, "tpu", False) + args.bf16 = getattr(args, "bf16", False) + if args.bf16: + args.tpu = True + if args.tpu and args.fp16: + raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs") + + if getattr(args, "seed", None) is None: + args.seed = 1 # default seed for training + args.no_seed_provided = True + else: + args.no_seed_provided = False + + # Apply architecture configuration. + if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY: + ARCH_CONFIG_REGISTRY[args.arch](args) + + if parse_known: + return args, extra + else: + return args + + +def get_parser(desc, default_task="translation"): + # Before creating the true parser, we need to import optional user module + # in order to eagerly import custom tasks, optimizers, architectures, etc. + usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + usr_parser.add_argument("--user-dir", default=None) + usr_args, _ = usr_parser.parse_known_args() + utils.import_user_module(usr_args) + + parser = argparse.ArgumentParser(allow_abbrev=False) + gen_parser_from_dataclass(parser, CommonConfig()) + + from fairseq.registry import REGISTRIES + + for registry_name, REGISTRY in REGISTRIES.items(): + parser.add_argument( + "--" + registry_name.replace("_", "-"), + default=REGISTRY["default"], + choices=REGISTRY["registry"].keys(), + ) + + # Task definitions can be found under fairseq/tasks/ + from fairseq.tasks import TASK_REGISTRY + + parser.add_argument( + "--task", + metavar="TASK", + default=default_task, + choices=TASK_REGISTRY.keys(), + help="task", + ) + # fmt: on + return parser + + +def add_preprocess_args(parser): + group = parser.add_argument_group("Preprocessing") + # fmt: off + group.add_argument("-s", "--source-lang", default=None, metavar="SRC", + help="source language") + group.add_argument("-t", "--target-lang", default=None, metavar="TARGET", + help="target language") + group.add_argument("--trainpref", metavar="FP", default=None, + help="train file prefix (also used to build dictionaries)") + group.add_argument("--validpref", metavar="FP", default=None, + help="comma separated, valid file prefixes " + "(words missing from train set are replaced with )") + group.add_argument("--testpref", metavar="FP", default=None, + help="comma separated, test file prefixes " + "(words missing from train set are replaced with )") + group.add_argument("--align-suffix", metavar="FP", default=None, + help="alignment file suffix") + group.add_argument("--destdir", metavar="DIR", default="data-bin", + help="destination dir") + group.add_argument("--thresholdtgt", metavar="N", default=0, type=int, + help="map words appearing less than threshold times to unknown") + group.add_argument("--thresholdsrc", metavar="N", default=0, type=int, + help="map words appearing less than threshold times to unknown") + group.add_argument("--tgtdict", metavar="FP", + help="reuse given target dictionary") + group.add_argument("--srcdict", metavar="FP", + help="reuse given source dictionary") + group.add_argument("--nwordstgt", metavar="N", default=-1, type=int, + help="number of target words to retain") + group.add_argument("--nwordssrc", metavar="N", default=-1, type=int, + help="number of source words to retain") + group.add_argument("--alignfile", metavar="ALIGN", default=None, + help="an alignment file (optional)") + parser.add_argument('--dataset-impl', metavar='FORMAT', default='mmap', + choices=get_available_dataset_impl(), + help='output dataset implementation') + group.add_argument("--joined-dictionary", action="store_true", + help="Generate joined dictionary") + group.add_argument("--only-source", action="store_true", + help="Only process the source language") + group.add_argument("--padding-factor", metavar="N", default=8, type=int, + help="Pad dictionary size to be multiple of N") + group.add_argument("--workers", metavar="N", default=1, type=int, + help="number of parallel workers") + # fmt: on + return parser + + +def add_dataset_args(parser, train=False, gen=False): + group = parser.add_argument_group("dataset_data_loading") + gen_parser_from_dataclass(group, DatasetConfig()) + # fmt: on + return group + + +def add_distributed_training_args(parser, default_world_size=None): + group = parser.add_argument_group("distributed_training") + if default_world_size is None: + default_world_size = max(1, torch.cuda.device_count()) + gen_parser_from_dataclass( + group, DistributedTrainingConfig(distributed_world_size=default_world_size) + ) + return group + + +def add_optimization_args(parser): + group = parser.add_argument_group("optimization") + # fmt: off + gen_parser_from_dataclass(group, OptimizationConfig()) + # fmt: on + return group + + +def add_checkpoint_args(parser): + group = parser.add_argument_group("checkpoint") + # fmt: off + gen_parser_from_dataclass(group, CheckpointConfig()) + # fmt: on + return group + + +def add_common_eval_args(group): + gen_parser_from_dataclass(group, CommonEvalConfig()) + + +def add_eval_lm_args(parser): + group = parser.add_argument_group("LM Evaluation") + add_common_eval_args(group) + gen_parser_from_dataclass(group, EvalLMConfig()) + + +def add_generation_args(parser): + group = parser.add_argument_group("Generation") + add_common_eval_args(group) + gen_parser_from_dataclass(group, GenerationConfig()) + return group + + +def add_interactive_args(parser): + group = parser.add_argument_group("Interactive") + gen_parser_from_dataclass(group, InteractiveConfig()) + + +def add_model_args(parser): + group = parser.add_argument_group("Model configuration") + # fmt: off + + # Model definitions can be found under fairseq/models/ + # + # The model architecture can be specified in several ways. + # In increasing order of priority: + # 1) model defaults (lowest priority) + # 2) --arch argument + # 3) --encoder/decoder-* arguments (highest priority) + from fairseq.models import ARCH_MODEL_REGISTRY + group.add_argument('--arch', '-a', metavar='ARCH', + choices=ARCH_MODEL_REGISTRY.keys(), + help='model architecture') + # fmt: on + return group diff --git a/fairseq/pdb.py b/fairseq/pdb.py new file mode 100644 index 0000000..1ba6ef0 --- /dev/null +++ b/fairseq/pdb.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import multiprocessing +import os +import pdb +import sys + + +__all__ = ["set_trace"] + + +_stdin = [None] +_stdin_lock = multiprocessing.Lock() +try: + _stdin_fd = sys.stdin.fileno() +except Exception: + _stdin_fd = None + + +class MultiprocessingPdb(pdb.Pdb): + """A Pdb wrapper that works in a multiprocessing environment. + + Usage: `from fairseq import pdb; pdb.set_trace()` + """ + + def __init__(self): + pdb.Pdb.__init__(self, nosigint=True) + + def _cmdloop(self): + stdin_bak = sys.stdin + with _stdin_lock: + try: + if _stdin_fd is not None: + if not _stdin[0]: + _stdin[0] = os.fdopen(_stdin_fd) + sys.stdin = _stdin[0] + self.cmdloop() + finally: + sys.stdin = stdin_bak + + +def set_trace(): + pdb = MultiprocessingPdb() + pdb.set_trace(sys._getframe().f_back) diff --git a/fairseq/quantization_utils.py b/fairseq/quantization_utils.py new file mode 100644 index 0000000..11fc414 --- /dev/null +++ b/fairseq/quantization_utils.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from fairseq.modules.quantization import pq, quantization_options, scalar +from omegaconf import DictConfig + + +logger = logging.getLogger(__name__) + + +def quantize_model_scalar(model, model_cfg: DictConfig): + quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0 + if quant_noise_scalar > 0: + # quantize_model edits the model in place + scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000) + return model + + +class Quantizer(object): + def __init__(self, config_path, max_epoch, max_update): + try: + import yaml + except ImportError: + raise ImportError("Please install yaml with: pip install yaml") + + # parse config + if config_path: + with open(config_path) as config_file: + config = quantization_options.parse_config_yaml( + yaml.safe_load(config_file) + ) + else: + config = quantization_options.parse_config_yaml({}) + + self.n_centroids_config = config["n_centroids"] + self.block_sizes_config = config["block_sizes"] + self.layers_to_quantize = config["layers_to_quantize"] + + # We assume that training will run for a fixed number of epochs + # (or updates) and that we should train for equal durations + # between iterations of PQ. + num_iterations = len(self.layers_to_quantize) + if max_epoch > 0: + assert max_epoch % num_iterations == 0, ( + "for iterative PQ, --max-epoch (={}) must be evenly divisible by " + "len(layers_to_quantize) (={})".format(max_epoch, num_iterations) + ) + self.epoch_schedule = max_epoch // num_iterations + else: + self.epoch_schedule = None + if max_update > 0: + assert max_update % num_iterations == 0, ( + "for iterative PQ, --max-update (={}) must be evenly divisible by " + "len(layers_to_quantize) (={})".format(max_update, num_iterations) + ) + self.update_schedule = max_update // num_iterations + else: + self.update_schedule = None + assert (self.epoch_schedule is not None) ^ ( + self.update_schedule is not None + ), "for iterative PQ, cannot specify both --max-update and --max-epoch" + + # 0 is a special value for quantization step, which will force + # the first call to begin_epoch() to call step() + self.quantization_step = 0 + + def set_trainer(self, trainer): + self.trainer = trainer + self.size_tracker = pq.SizeTracker(self.trainer.get_model()) + + def step(self): + """Move to the next stage of quantization.""" + if self.quantization_step >= len(self.layers_to_quantize): + # Maybe we just finished the last training step or we loaded + # a checkpoint for an iterative PQ model which previously + # finished training. Either way, don't quantize again. + return + + logger.info( + "quantizing model (step={}; layers_to_quantize[step]={})".format( + self.quantization_step, self.layers_to_quantize[self.quantization_step] + ) + ) + quantized_layers = pq.quantize_model_( + self.trainer.get_model(), + self.size_tracker, + self.layers_to_quantize, + self.block_sizes_config, + self.n_centroids_config, + step=self.quantization_step, + ) + logger.info("quantized layers: {}".format(quantized_layers)) + logger.info(self.size_tracker) + + self.quantization_step += 1 + + # reintialize the Trainer since model parameters have changed + self.trainer.reinitialize() + + def begin_epoch(self, epoch): + """Called at the beginning of each epoch (epochs start at 1).""" + if ( + ( + self.epoch_schedule is not None + and epoch > 0 + and (epoch - 1) % self.epoch_schedule == 0 + ) + # we always step once in the beginning, even if using + # update-based quantization + or self.quantization_step == 0 + ): + self.step() + + def step_update(self, num_updates): + """Called at the end of each step.""" + if ( + self.update_schedule is not None + and num_updates > 0 + and num_updates % self.update_schedule == 0 + ): + self.step() + + def state_dict(self): + return { + "n_centroids_config": self.n_centroids_config, + "block_sizes_config": self.block_sizes_config, + "layers_to_quantize": self.layers_to_quantize, + "epoch_schedule": self.epoch_schedule, + "update_schedule": self.update_schedule, + "quantization_step": self.quantization_step, + } + + def load_state_dict(self, state_dict): + self.n_centroids_config = state_dict["n_centroids_config"] + self.block_sizes_config = state_dict["block_sizes_config"] + self.layers_to_quantize = state_dict["layers_to_quantize"] + self.epoch_schedule = state_dict["epoch_schedule"] + self.update_schedule = state_dict["update_schedule"] + self.quantization_step = state_dict["quantization_step"] diff --git a/fairseq/registry.py b/fairseq/registry.py new file mode 100644 index 0000000..96994cb --- /dev/null +++ b/fairseq/registry.py @@ -0,0 +1,84 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace + +from typing import Union +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import populate_dataclass +from omegaconf import DictConfig + +REGISTRIES = {} + + +def setup_registry(registry_name: str, base_class=None, default=None, required=False): + assert registry_name.startswith("--") + registry_name = registry_name[2:].replace("-", "_") + + REGISTRY = {} + REGISTRY_CLASS_NAMES = set() + DATACLASS_REGISTRY = {} + + # maintain a registry of all registries + if registry_name in REGISTRIES: + return # registry already exists + REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default, "dataclass_registry": DATACLASS_REGISTRY} + + def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs): + if isinstance(cfg, DictConfig): + choice = cfg._name + elif isinstance(cfg, str): + choice = cfg + if choice in DATACLASS_REGISTRY: + cfg = DATACLASS_REGISTRY[choice]() + else: + choice = getattr(cfg, registry_name, None) + if choice in DATACLASS_REGISTRY: + cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]()) + + if choice is None: + if required: + raise ValueError('{} is required!'.format(registry_name)) + return None + + cls = REGISTRY[choice] + if hasattr(cls, "build_" + registry_name): + builder = getattr(cls, "build_" + registry_name) + else: + builder = cls + + return builder(cfg, *extra_args, **extra_kwargs) + + def register_x(name, dataclass=None): + def register_x_cls(cls): + if name in REGISTRY: + raise ValueError( + "Cannot register duplicate {} ({})".format(registry_name, name) + ) + if cls.__name__ in REGISTRY_CLASS_NAMES: + raise ValueError( + "Cannot register {} with duplicate class name ({})".format( + registry_name, cls.__name__ + ) + ) + if base_class is not None and not issubclass(cls, base_class): + raise ValueError( + "{} must extend {}".format(cls.__name__, base_class.__name__) + ) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + REGISTRY[name] = cls + if cls.__dataclass is not None: + DATACLASS_REGISTRY[name] = cls.__dataclass + return cls + + return register_x_cls + + return build_x, register_x, REGISTRY, DATACLASS_REGISTRY diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py new file mode 100644 index 0000000..9163be8 --- /dev/null +++ b/fairseq/scoring/__init__.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os +from abc import ABC, abstractmethod + +from fairseq import registry +from omegaconf import DictConfig + + +class BaseScorer(ABC): + def __init__(self, cfg): + self.cfg = cfg + self.ref = [] + self.pred = [] + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + @abstractmethod + def score(self) -> float: + pass + + @abstractmethod + def result_string(self) -> str: + pass + + +_build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry( + "--scoring", default="bleu" +) + + +def build_scorer(choice, tgt_dict): + if isinstance(choice, DictConfig): + choice = choice._name + + if choice == "bleu": + from fairseq.scoring import bleu + + return bleu.Scorer( + bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()) + ) + return _build_scorer(choice) + + +# automatically import any Python files in the current directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.scoring." + module) diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py new file mode 100644 index 0000000..97de5f9 --- /dev/null +++ b/fairseq/scoring/bleu.py @@ -0,0 +1,167 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ctypes +import math +import sys +from dataclasses import dataclass, field + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer +from fairseq.scoring.tokenizer import EvaluationTokenizer + + +class BleuStat(ctypes.Structure): + _fields_ = [ + ("reflen", ctypes.c_size_t), + ("predlen", ctypes.c_size_t), + ("match1", ctypes.c_size_t), + ("count1", ctypes.c_size_t), + ("match2", ctypes.c_size_t), + ("count2", ctypes.c_size_t), + ("match3", ctypes.c_size_t), + ("count3", ctypes.c_size_t), + ("match4", ctypes.c_size_t), + ("count4", ctypes.c_size_t), + ] + + +@dataclass +class SacrebleuConfig(FairseqDataclass): + sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="13a", metadata={"help": "tokenizer"} + ) + sacrebleu_lowercase: bool = field( + default=False, metadata={"help": "apply lowercasing"} + ) + sacrebleu_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + + +@register_scorer("sacrebleu", dataclass=SacrebleuConfig) +class SacrebleuScorer(BaseScorer): + def __init__(self, cfg): + super(SacrebleuScorer, self).__init__(cfg) + import sacrebleu + + self.sacrebleu = sacrebleu + self.tokenizer = EvaluationTokenizer( + tokenizer_type=cfg.sacrebleu_tokenizer, + lowercase=cfg.sacrebleu_lowercase, + character_tokenization=cfg.sacrebleu_char_level, + ) + + def add_string(self, ref, pred): + self.ref.append(self.tokenizer.tokenize(ref)) + self.pred.append(self.tokenizer.tokenize(pred)) + + def score(self, order=4): + return self.result_string(order).score + + def result_string(self, order=4): + if order != 4: + raise NotImplementedError + # tokenization and lowercasing are performed by self.tokenizer instead. + return self.sacrebleu.corpus_bleu( + self.pred, [self.ref], tokenize="none" + ).format() + + +@dataclass +class BleuConfig(FairseqDataclass): + pad: int = field(default=1, metadata={"help": "padding index"}) + eos: int = field(default=2, metadata={"help": "eos index"}) + unk: int = field(default=3, metadata={"help": "unk index"}) + + +@register_scorer("bleu", dataclass=BleuConfig) +class Scorer(object): + def __init__(self, cfg): + self.stat = BleuStat() + self.pad = cfg.pad + self.eos = cfg.eos + self.unk = cfg.unk + + try: + from fairseq import libbleu + except ImportError as e: + sys.stderr.write( + "ERROR: missing libbleu.so. run `pip install --editable .`\n" + ) + raise e + + self.C = ctypes.cdll.LoadLibrary(libbleu.__file__) + + self.reset() + + def reset(self, one_init=False): + if one_init: + self.C.bleu_one_init(ctypes.byref(self.stat)) + else: + self.C.bleu_zero_init(ctypes.byref(self.stat)) + + def add(self, ref, pred): + if not isinstance(ref, torch.IntTensor): + raise TypeError("ref must be a torch.IntTensor (got {})".format(type(ref))) + if not isinstance(pred, torch.IntTensor): + raise TypeError("pred must be a torch.IntTensor(got {})".format(type(pred))) + + # don't match unknown words + rref = ref.clone() + assert not rref.lt(0).any() + rref[rref.eq(self.unk)] = -999 + + rref = rref.contiguous().view(-1) + pred = pred.contiguous().view(-1) + + self.C.bleu_add( + ctypes.byref(self.stat), + ctypes.c_size_t(rref.size(0)), + ctypes.c_void_p(rref.data_ptr()), + ctypes.c_size_t(pred.size(0)), + ctypes.c_void_p(pred.data_ptr()), + ctypes.c_int(self.pad), + ctypes.c_int(self.eos), + ) + + def score(self, order=4): + psum = sum( + math.log(p) if p > 0 else float("-Inf") for p in self.precision()[:order] + ) + return self.brevity() * math.exp(psum / order) * 100 + + def precision(self): + def ratio(a, b): + return a / b if b > 0 else 0 + + return [ + ratio(self.stat.match1, self.stat.count1), + ratio(self.stat.match2, self.stat.count2), + ratio(self.stat.match3, self.stat.count3), + ratio(self.stat.match4, self.stat.count4), + ] + + def brevity(self): + r = self.stat.reflen / self.stat.predlen + return min(1, math.exp(1 - r)) + + def result_string(self, order=4): + assert order <= 4, "BLEU scores for order > 4 aren't supported" + fmt = "BLEU{} = {:2.2f}, {:2.1f}" + for _ in range(1, order): + fmt += "/{:2.1f}" + fmt += " (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})" + bleup = [p * 100 for p in self.precision()[:order]] + return fmt.format( + order, + self.score(order=order), + *bleup, + self.brevity(), + self.stat.predlen / self.stat.reflen, + self.stat.predlen, + self.stat.reflen + ) diff --git a/fairseq/scoring/chrf.py b/fairseq/scoring/chrf.py new file mode 100644 index 0000000..0d6cb77 --- /dev/null +++ b/fairseq/scoring/chrf.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.scoring import BaseScorer, register_scorer + + +@register_scorer("chrf") +class ChrFScorer(BaseScorer): + def __init__(self, args): + super(ChrFScorer, self).__init__(args) + import sacrebleu + + self.sacrebleu = sacrebleu + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + def score(self, order=4): + return self.result_string(order).score + + def result_string(self, order=4): + if order != 4: + raise NotImplementedError + return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format() diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py new file mode 100644 index 0000000..61cf6d4 --- /dev/null +++ b/fairseq/scoring/tokenizer.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unicodedata + +from fairseq.dataclass import ChoiceEnum + + +class EvaluationTokenizer(object): + """A generic evaluation-time tokenizer, which leverages built-in tokenizers + in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides + lowercasing, punctuation removal and character tokenization, which are + applied after sacreBLEU tokenization. + + Args: + tokenizer_type (str): the type of sacreBLEU tokenizer to apply. + lowercase (bool): lowercase the text. + punctuation_removal (bool): remove punctuation (based on unicode + category) from text. + character_tokenization (bool): tokenize the text to characters. + """ + + SPACE = chr(32) + SPACE_ESCAPE = chr(9601) + ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"]) + + def __init__( + self, + tokenizer_type: str = "13a", + lowercase: bool = False, + punctuation_removal: bool = False, + character_tokenization: bool = False, + ): + from sacrebleu.tokenizers import TOKENIZERS + + assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}" + self.lowercase = lowercase + self.punctuation_removal = punctuation_removal + self.character_tokenization = character_tokenization + self.tokenizer = TOKENIZERS[tokenizer_type] + + @classmethod + def remove_punctuation(cls, sent: str): + """Remove punctuation based on Unicode category.""" + return cls.SPACE.join( + t + for t in sent.split(cls.SPACE) + if not all(unicodedata.category(c)[0] == "P" for c in t) + ) + + def tokenize(self, sent: str): + tokenized = self.tokenizer()(sent) + + if self.punctuation_removal: + tokenized = self.remove_punctuation(tokenized) + + if self.character_tokenization: + tokenized = self.SPACE.join( + list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)) + ) + + if self.lowercase: + tokenized = tokenized.lower() + + return tokenized diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py new file mode 100644 index 0000000..633dc47 --- /dev/null +++ b/fairseq/scoring/wer.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer +from fairseq.scoring.tokenizer import EvaluationTokenizer + + +@dataclass +class WerScorerConfig(FairseqDataclass): + wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"} + ) + wer_remove_punct: bool = field( + default=False, metadata={"help": "remove punctuation"} + ) + wer_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"}) + + +@register_scorer("wer", dataclass=WerScorerConfig) +class WerScorer(BaseScorer): + def __init__(self, cfg): + super().__init__(cfg) + self.reset() + try: + import editdistance as ed + except ImportError: + raise ImportError("Please install editdistance to use WER scorer") + self.ed = ed + self.tokenizer = EvaluationTokenizer( + tokenizer_type=self.cfg.wer_tokenizer, + lowercase=self.cfg.wer_lowercase, + punctuation_removal=self.cfg.wer_remove_punct, + character_tokenization=self.cfg.wer_char_level, + ) + + def reset(self): + self.distance = 0 + self.ref_length = 0 + + def add_string(self, ref, pred): + ref_items = self.tokenizer.tokenize(ref).split() + pred_items = self.tokenizer.tokenize(pred).split() + self.distance += self.ed.eval(ref_items, pred_items) + self.ref_length += len(ref_items) + + def result_string(self): + return f"WER: {self.score():.2f}" + + def score(self): + return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0 diff --git a/fairseq/search.py b/fairseq/search.py new file mode 100644 index 0000000..d5ea68b --- /dev/null +++ b/fairseq/search.py @@ -0,0 +1,814 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import List, Optional + +import torch +import torch.nn as nn +from fairseq.token_generation_constraints import ( + ConstraintState, + OrderedConstraintState, + UnorderedConstraintState, +) +from torch import Tensor + + +class Search(nn.Module): + def __init__(self, tgt_dict): + super().__init__() + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.src_lengths = torch.tensor(-1) + self.supports_constraints = False + self.stop_on_max_len = False + + def step( + self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None + ): + """Take a single search step. + + Args: + step: the current search step, starting at 0 + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + scores: (bsz x input_beam_size x step) + the historical model scores of each hypothesis up to this point + prev_output_tokens: (bsz x step) + the previously generated oputput tokens + original_batch_idxs: (bsz) + the tensor with the batch indices, in the range [0, bsz) + this is useful in case there has been applied a re-ordering + and we need to know the orignal indices + + Return: A tuple of (scores, indices, beams) where: + scores: (bsz x output_beam_size) + the scores of the chosen elements; output_beam_size can be + larger than input_beam_size, e.g., we may return + 2*input_beam_size to account for EOS + indices: (bsz x output_beam_size) + the indices of the chosen elements + beams: (bsz x output_beam_size) + the hypothesis ids of the chosen elements, in the range [0, input_beam_size) + """ + raise NotImplementedError + + @torch.jit.export + def set_src_lengths(self, src_lengths): + self.src_lengths = src_lengths + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + """Initialize constraint states for constrained decoding (if supported). + + Args: + batch_constraints: (torch.Tensor, optional) + the list of constraints, in packed form + beam_size: (int) + the beam size + Returns: + *encoder_out* rearranged according to *new_order* + """ + pass + + def prune_sentences(self, batch_idxs: Tensor): + """ + Removes constraint states for completed sentences (if supported). + This is called from sequence_generator._generate() when sentences are + deleted from the batch. + + Args: + batch_idxs: Indices of *sentences* whose constraint state should be *kept*. + """ + pass + + def update_constraints(self, active_hypos: Tensor): + """ + Updates the constraint states by selecting the beam items that are retained. + This is called at each time step of sequence_generator._generate() when + the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size. + + Args: + active_hypos: (batch size, beam size) + list of integers denoting, for each sentence, which beam candidate items + should be kept. + """ + pass + + +class BeamSearch(Search): + def __init__(self, tgt_dict): + super().__init__(tgt_dict) + self.constraint_states = None + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + # Project back into relative indices and beams + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices + return scores_buf, indices_buf, beams_buf + + +class PrefixConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, prefix_allowed_tokens_fn): + super().__init__(tgt_dict) + self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self.stop_on_max_len = True + + @torch.jit.export + def apply_mask(self, x, prev_output_tokens, original_batch_idxs): + beam_size = x.shape[0] // original_batch_idxs.shape[0] + original_batch_idxs = ( + original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist() + ) + + mask = torch.full_like(x, -math.inf) + for sent_i, (sent, batch_i) in enumerate( + zip(prev_output_tokens, original_batch_idxs) + ): + mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 + + return mask + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Tensor, + prev_output_tokens: Tensor, + original_batch_idxs: Tensor, + ): + bsz, beam_size, vocab_size = lprobs.size() + + lprobs += self.apply_mask( + lprobs.view(bsz * beam_size, 1, vocab_size), + prev_output_tokens, + original_batch_idxs, + ).view(bsz, beam_size, vocab_size) + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + return scores_buf, indices_buf, beams_buf + + +class LexicallyConstrainedBeamSearch(Search): + """Implements lexically constrained beam search as described in + + Fast Lexically Constrained Decoding with Dynamic Beam + Allocation for Neural Machine Translation. Post & Vilar, + NAACL 2018. https://www.aclweb.org/anthology/N18-1119/ + + and + + Improved Lexically Constrained Decoding for Translation and + Monolingual Rewriting. Hu et al, NAACL + 2019. https://www.aclweb.org/anthology/N19-1090/ + + This is accomplished by maintaining, for each beam hypothesis, a + ConstraintState object (see constraints.py) that tracks which + constraints have been generated and using this information to + shape the beam for each input sentence. + """ + + def __init__(self, tgt_dict, representation): + super().__init__(tgt_dict) + self.representation = representation + self.vocab_size = len(tgt_dict) + self.num_cands = 0 + self.supports_constraints = True + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + self.constraint_states = [] + for constraint_tensor in batch_constraints: + if self.representation == "ordered": + constraint_state = OrderedConstraintState.create(constraint_tensor) + elif self.representation == "unordered": + constraint_state = UnorderedConstraintState.create(constraint_tensor) + + self.constraint_states.append([constraint_state for i in range(beam_size)]) + + @torch.jit.export + def prune_sentences(self, batch_idxs: Tensor): + self.constraint_states = [ + self.constraint_states[i] for i in batch_idxs.tolist() + ] + + @torch.jit.export + def update_constraints(self, active_hypos: Tensor): + if self.constraint_states: + batch_size = active_hypos.size(0) + for sentid in range(batch_size): + self.constraint_states[sentid] = [ + self.constraint_states[sentid][i] for i in active_hypos[sentid] + ] + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + """ + A constrained step builds a large candidates list from the following: + - the top 2 * {beam_size} items over the whole beam + - for each item in the beam + - the top {each_k} (default 1) + - all next constraints + We then compute the constrained state of each beam item, and assign + stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so + on. We then sort by (stripe, score), and truncate the list at + 2 * beam size. + + Args: + step: the decoder step + lprobs: (batch size, beam size, target vocab) + the target-vocab distributions for each item in the beam. + Retrun: A tuple of (scores, indices, beams, constraints) where: + scores: (batch, output beam size) + the scores of the chosen elements + indices: (batch, output beam size) + the target vocab indices of the chosen elements + beams: (batch, output beam size) + the 0-indexed hypothesis ids of the chosen elements + constraints: (batch, output beam size) + the new constraint states + """ + each_k = 1 + device = lprobs.device + + batch_size, beam_size, vocab_size = lprobs.size() + + self.num_cands = min( + # Just take the k-best. We'll get another k from the 1-best from each + # row, plus more from the constraints + beam_size * 2, + lprobs.view(batch_size, -1).size(1) - 1, # -1 so we never select pad + ) + + # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items + constraint_states = self.constraint_states + if constraint_states and step > 0: + not_finished_indices = [] + for sentno, sent_constraints in enumerate(constraint_states): + for beamno, state in enumerate(sent_constraints): + index = sentno * beam_size + beamno + if not state.finished: + not_finished_indices.append(index) + not_finished_indices = torch.tensor(not_finished_indices) + if not_finished_indices.numel() > 0: + lprobs.view(batch_size * beam_size, -1)[ + not_finished_indices, self.eos + ] = -math.inf + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam entry for each batch item + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(batch_size, -1), + self.num_cands, + ) + scores_buf, indices_buf = top_prediction + # Project back into relative indices and beams + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # Short circuit if there are no constraints in this batch + if not constraint_states: + return scores_buf, indices_buf, beams_buf + + # STEP 1: get top-1 from each hypothesis across all sentences in the batch + if step > 0: + top_scores, top_indices = torch.topk( + lprobs.view(batch_size * beam_size, -1), + k=each_k, + dim=1, + ) + top_scores = top_scores.view(batch_size, -1) + top_indices = top_indices.view(batch_size, -1) + scores_buf = torch.cat((scores_buf, top_scores), dim=1) + indices_buf = torch.cat((indices_buf, top_indices), dim=1) + new_beams = torch.arange(0, beam_size, device=device).repeat(batch_size, 1) + beams_buf = torch.cat((beams_buf, new_beams), dim=1) + + # Now, process sentences in the batch one by one. + new_scores_buf = torch.zeros((batch_size, 2 * beam_size), device=device) + new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + for sentno, states in enumerate(constraint_states): + scores, indices, beams, new_states = self.step_sentence( + step, + sentno, + lprobs[sentno], + constraint_states[sentno], + beams_buf[sentno].clone(), + indices_buf[sentno].clone(), + scores_buf[sentno].clone(), + ) + new_scores_buf[sentno] = scores + new_indices_buf[sentno] = indices + new_beams_buf[sentno] = beams + self.constraint_states[sentno] = new_states + + return new_scores_buf, new_indices_buf, new_beams_buf + + @torch.jit.export + def step_sentence( + self, + step: int, + sentno: int, + lprobs: Tensor, + constraint_states: List[List[ConstraintState]], + beams_buf: Tensor, + indices_buf: Tensor, + scores_buf: Tensor, + ): + """Does per-sentence processing. Adds all constraints for each + hypothesis to the list of candidates; then removes duplicates, + sorts, and dynamically stripes across the banks. All tensor inputs + are collapsed to those pertaining to a single input sentence. + """ + device = lprobs.device + + # STEP 2: Add all constraints for each beam item + for beamno, state in enumerate(constraint_states): + next_tokens = torch.tensor(list(state.next_tokens()), device=device).long() + if next_tokens.numel() != 0: + indices_buf = torch.cat((indices_buf, next_tokens)) + next_beams = ( + torch.tensor(beamno, device=device) + .repeat(next_tokens.size(0)) + .long() + ) + beams_buf = torch.cat((beams_buf, next_beams)) + next_values = lprobs[beamno].take(next_tokens.view(-1)) + scores_buf = torch.cat((scores_buf, next_values)) + + # At the 0th time step, there is just one beam item + if step == 0: + break + + # STEP 3: Compute the "bank" for each candidate. This is the + # number of constraints it's generated. We need this so that + # we can do round-robin allocation of the beam across these + # banks. If C is the number of constraints, we select the best + # item in bank C, then the best in bank C-1, etc, followed by + # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so + # on, until the maximum beam size. We accomplish this by + # creating a sort key and striping across the banks. + + # Compute the new states for all candidates + cands_size = indices_buf.size(0) + constraint_states = [ + constraint_states[beams_buf[i]].advance(indices_buf[i]) + for i in range(cands_size) + ] + + banks = torch.tensor([state.bank for state in constraint_states], device=device) + + # STEP 4: Sort + num_constraint_tokens = len(state.tokens) + + # Sort by keys (bank, score) (i.e., sort banks together, and scores + # within banks). AFAIK pytorch doesn't support either stable sort or + # multi-key sorting, so we have to hack this. + MAX_SCORE = -100 + sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf + sort_values, sort_indices = sort_key.sort(dim=0, descending=True) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + banks = banks[sort_indices] + + # Sort the constraints to follow suit + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 5: Remove duplicates. The topk calls (overall and + # per-row) plus the per-row generation of constraints will + # produce duplicates. Here we remove them. + + def roll(t): + """Rolls a 1d tensor left by 1. + + [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] + """ + return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) + + # We map candidates (beam, token_id) to a single dimension. + # This is then shifted by 1. We can then easily identify + # duplicates and create a mask that identifies unique + # extensions. + uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf + uniques_mask = roll(uniques_mask) != uniques_mask + + # Use the mask to pare down the data structures + scores_buf = torch.masked_select(scores_buf, uniques_mask) + indices_buf = torch.masked_select(indices_buf, uniques_mask) + beams_buf = torch.masked_select(beams_buf, uniques_mask) + banks = torch.masked_select(banks, uniques_mask) + i = 1 + for mask in uniques_mask[1:]: + if not mask: + constraint_states.pop(i) + i += mask + + # STEP 6: Assign IDs round-robin across banks, sort, and + # truncate. Now that the candidates are sorted by (bank, + # score) and uniqed, we dynamically allocate the {beam_size} + # beam by striping across the candidates. These stripes will + # be used as sort keys to do round-robin selection. This is + # accomplished in a single pass with offsets. Sorting by + # highest-banks (furthest-along hypotheses) first ensures + # progress through the constraints. + # + # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 + # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 + # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 + # = 0 5 10 1 6 11 13 2 7 12 3 8 + # + # Sorting by this then gives the following banks: + # + # 3 2 1 0 3 2 1 0 3 2 1 2 + # + # We'll take the top {beam_size} of these. + stripe_offsets = [offset * (len(banks) + 1) for offset in range(len(banks) + 1)] + stripes = torch.zeros_like(banks) + cur_bank_count = -1 + cur_bank = banks[0] + for i, bank in enumerate(banks): + if bank != cur_bank: + cur_bank_count = 0 + cur_bank = bank + else: + cur_bank_count += 1 + stripes[i] = num_constraint_tokens - bank + stripe_offsets[cur_bank_count] + + # STEP 7: Sort by the stripes values + sort_values, sort_indices = stripes.sort(dim=0) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 8: Truncate to the candidates size! + scores_buf = scores_buf[: self.num_cands] + indices_buf = indices_buf[: self.num_cands] + beams_buf = beams_buf[: self.num_cands] + + return scores_buf, indices_buf, beams_buf, constraint_states + + +class LengthConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): + super().__init__(tgt_dict) + self.min_len_a = min_len_a + self.min_len_b = min_len_b + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.beam = BeamSearch(tgt_dict) + self.needs_src_lengths = True + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + min_lens = self.min_len_a * self.src_lengths + self.min_len_b + max_lens = self.max_len_a * self.src_lengths + self.max_len_b + lprobs[step < min_lens, :, self.eos] = -math.inf + lprobs[step >= max_lens, :, self.eos] = 0 + return self.beam.step(step, lprobs, scores) + + +class DiverseBeamSearch(Search): + """Diverse Beam Search. + + See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence + Models" for details. + + We only implement the Hamming Diversity penalty here, which performed best + in the original paper. + """ + + def __init__(self, tgt_dict, num_groups, diversity_strength): + super().__init__(tgt_dict) + self.num_groups = num_groups + self.diversity_strength = -diversity_strength + self.beam = BeamSearch(tgt_dict) + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + if beam_size % self.num_groups != 0: + raise ValueError( + "DiverseBeamSearch requires --beam to be divisible by the number of groups" + ) + + # initialize diversity penalty + diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs) + + scores_G, indices_G, beams_G = [], [], [] + for g in range(self.num_groups): + lprobs_g = lprobs[:, g :: self.num_groups, :] + scores_g = scores[:, g :: self.num_groups, :] if step > 0 else None + + # apply diversity penalty + if g > 0: + lprobs_g = torch.add( + lprobs_g, + other=diversity_buf.unsqueeze(1), + alpha=self.diversity_strength, + ) + else: + lprobs_g = lprobs_g.contiguous() + + scores_buf, indices_buf, beams_buf = self.beam.step( + step, lprobs_g, scores_g + ) + beams_buf.mul_(self.num_groups).add_(g) + + scores_G.append(scores_buf.clone()) + indices_G.append(indices_buf.clone()) + beams_G.append(beams_buf.clone()) + + # update diversity penalty + diversity_buf.scatter_add_( + 1, indices_buf, torch.ones(indices_buf.size()).to(diversity_buf) + ) + + # interleave results from different groups + scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1) + indices_buf = torch.stack(indices_G, dim=2).view(bsz, -1) + beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1) + return scores_buf, indices_buf, beams_buf + + +class Sampling(Search): + sampling_topk: int + sampling_topp: float + + def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0): + super().__init__(tgt_dict) + self.sampling_topk = sampling_topk + self.sampling_topp = sampling_topp + + def _sample_topp(self, lprobs): + """Sample among the smallest set of elements whose cumulative probability mass exceeds p. + + See `"The Curious Case of Neural Text Degeneration" + (Holtzman et al., 2019) `_. + + Args: + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + + Return: A tuple of (trimed_probs, truncated_indices) where: + trimed_probs: (bsz x input_beam_size x ?) + the model's probabilities over the elements selected to sample from. The + width of the third dimension is determined by top-P. + truncated_indices: (bsz x input_beam_size x ?) + the indices of the chosen elements. + """ + probs = lprobs.exp_() + + # sort the last dimension (vocab dimension) in descending order + sorted_probs, sorted_indices = probs.sort(descending=True) + + # compute a mask to indicate the words to be included in the top-P set. + cumsum_probs = sorted_probs.cumsum(dim=2) + mask = cumsum_probs.lt(self.sampling_topp) + + # note that mask was computed by 'lt'. One more word needs to be included + # so that the cumulative probability mass can exceed p. + cumsum_mask = mask.cumsum(dim=2) + last_included = cumsum_mask[:, :, -1:] + last_included.clamp_(0, mask.size()[2] - 1) + mask = mask.scatter_(2, last_included, 1) + + # truncate unnecessary dims. + max_dim = last_included.max() + truncated_mask = mask[:, :, : max_dim + 1] + truncated_probs = sorted_probs[:, :, : max_dim + 1] + truncated_indices = sorted_indices[:, :, : max_dim + 1] + + # trim the words that are not in top-P by setting their probabilities + # to 0, so that they would not be sampled later. + trim_mask = ~truncated_mask + trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) + return trimed_probs, truncated_indices + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + + if self.sampling_topp > 0: + # only sample from the smallest set of words whose cumulative probability mass exceeds p + probs, top_indices = self._sample_topp(lprobs) + elif self.sampling_topk > 0: + # only sample from top-k candidates + lprobs, top_indices = lprobs.topk(self.sampling_topk) + probs = lprobs.exp_() + else: + probs = lprobs.exp_() + + # dummy data to be consistent with true branch for type check + top_indices = torch.empty(0).to(probs) + # sample + if step == 0: + indices_buf = torch.multinomial( + probs.view(bsz, -1), + beam_size, + replacement=True, + ).view(bsz, beam_size) + else: + indices_buf = torch.multinomial( + probs.view(bsz * beam_size, -1), + 1, + replacement=True, + ).view(bsz, beam_size) + + if step == 0: + # expand to beam size + probs = probs.expand(bsz, beam_size, -1) + + # gather scores + scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1)) + scores_buf = scores_buf.log_().view(bsz, -1) + + # remap indices if using top-k or top-P sampling + if self.sampling_topk > 0 or self.sampling_topp > 0: + indices_buf = torch.gather( + top_indices.expand(bsz, beam_size, -1), + dim=2, + index=indices_buf.unsqueeze(-1), + ).squeeze(2) + + if step == 0: + beams_buf = indices_buf.new_zeros(bsz, beam_size) + else: + beams_buf = torch.arange(0, beam_size).to(indices_buf).repeat(bsz, 1) + # make scores cumulative + scores_buf.add_( + torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf) + ) + + return scores_buf, indices_buf, beams_buf + + +class DiverseSiblingsSearch(Search): + """ + Beam search with diverse siblings. + + See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details. + https://arxiv.org/abs/1611.08562 + + 1/ Calculate hypotheses for each beam + 2/ Intra-sibling ordering + 3/ Rewrite scores + 4/ Choose top K hypotheses + + if diversity_rate == 0 is equivalent to BeamSearch + """ + + def __init__(self, tgt_dict, diversity_rate): + super().__init__(tgt_dict) + self.diversity_rate = diversity_rate + self.beam = BeamSearch(tgt_dict) + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + k = min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ) + s_list: List[Tensor] + i_list: List[Tensor] + s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)] + i_list = [torch.LongTensor().to(device=lprobs.device) for i in range(beam_size)] + sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate + + if step == 0: + return self.beam.step(step, lprobs, scores) + lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) + + # 1/ Calculate hypotheses for each beam + for i in range(beam_size): + torch.topk(lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i])) + i_list[i].fmod_(vocab_size) + + # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores + s_list[i].sub_(sibling_score) + + # 4/ Choose top K hypotheses + indices = torch.stack(i_list, dim=1).view(bsz, -1) + + final_scores = torch.empty(0).to(lprobs) + final_indices = torch.LongTensor().to(device=lprobs.device) + final_beams = torch.LongTensor().to(device=lprobs.device) + (final_scores, final_indices) = torch.topk( + torch.stack(s_list, dim=1).view(bsz, -1), + k, + ) + + final_beams = final_indices // k + + for i in range(bsz): + final_indices[i] = indices[i][final_indices[i]] + + return final_scores, final_indices, final_beams diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py new file mode 100644 index 0000000..9c5423e --- /dev/null +++ b/fairseq/sequence_generator.py @@ -0,0 +1,1005 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from fairseq import search, utils +from fairseq.data import data_utils +from fairseq.models import FairseqIncrementalDecoder +from fairseq.models.fairseq_encoder import EncoderOut +from torch import Tensor + + +class SequenceGenerator(nn.Module): + def __init__( + self, + models, + tgt_dict, + beam_size=1, + max_len_a=0, + max_len_b=200, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + ): + """Generates translations of a given source sentence. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + if isinstance(models, EnsembleModel): + self.model = models + else: + self.model = EnsembleModel(models) + self.tgt_dict = tgt_dict + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + self.vocab_size = len(tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.temperature = temperature + self.match_source_len = match_source_len + self.no_repeat_ngram_size = no_repeat_ngram_size + assert temperature > 0, "--temperature must be greater than 0" + + self.search = ( + search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy + ) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. + self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) + + self.model.eval() + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. + + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + # TODO(myleott): unused, deprecate after pytorch-translate migration + def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): + """Iterate over a batched dataset and yield individual translations. + Args: + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + """ + for sample in data_itr: + s = utils.move_to_cuda(sample) if cuda else sample + if "net_input" not in s: + continue + input = s["net_input"] + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in input.items() if k != "prev_output_tokens" + } + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate(encoder_input) + if timer is not None: + timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) + for i, id in enumerate(s["id"].data): + # remove padding + src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) + ref = ( + utils.strip_pad(s["target"].data[i, :], self.pad) + if s["target"] is not None + else None + ) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): + """Generate translations. Match the api of other fairseq generators. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + net_input = sample["net_input"] + + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + ) + elif "source" in net_input: + src_tokens = net_input["source"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + else: + raise Exception("expected src_tokens or source in net input") + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimenions (i.e. audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + # exclude the EOS marker + self.model.max_decoder_positions() - 1, + ) + assert ( + self.min_len <= max_len + ), "min_len cannot be larger than max_len, please adjust these!" + # compute the encoder output for each beam + encoder_outs = self.model.forward_encoder(net_input) + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. + assert encoder_outs is not None + + # initialize buffers + scores = ( + torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = ( + torch.zeros(bsz * beam_size, max_len + 2) + .to(src_tokens) + .long() + .fill_(self.pad) + ) # +2 for eos and pad + tokens[:, 0] = self.eos if bos_token is None else bos_token + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = ( + torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + finished = [ + False for i in range(bsz) + ] # a boolean array indicating if the sentence at the index is finished or not + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) + cand_offsets = torch.arange(0, cand_size).type_as(tokens) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( + batch_idxs + ) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size + ) + original_batch_idxs = original_batch_idxs[batch_idxs] + self.model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = self.model.reorder_encoder_out( + encoder_outs, reorder_state + ) + + lprobs, avg_attn_scores = self.model.forward_decoder( + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, + ) + + if self.lm_model is not None: + lm_out = self.lm_model(tokens[:, : step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + # handle max length constraint + if step >= max_len: + lprobs[:, : self.eos] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf + + # handle prefix tokens (possibly with different lengths) + if ( + prefix_tokens is not None + and step < prefix_tokens.size(1) + and step < max_len + ): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size + ) + elif step < self.min_len: + # minimum length constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty( + bsz * beam_size, avg_attn_scores.size(1), max_len + 2 + ).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if self.no_repeat_ngram_size > 0: + lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, : step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. + if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) + batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1 + ) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. + + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False + ) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). + active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, : step + 1] = torch.index_select( + tokens[:, : step + 1], dim=0, index=active_bbsz_idx + ) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos + ) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx + ) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos + ) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, : step + 2] = torch.index_select( + attn[:, :, : step + 2], dim=0, index=active_bbsz_idx + ) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) + return finalized + + def _prefix_tokens( + self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int + ): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] + ) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ + :, 0, 1 : step + 1 + ] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. + A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select(0, bbsz_idx)[ + :, 1 : step + 2 + ] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] + if attn is not None + else None + ) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1) ** self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. + cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + + # The keys here are of the form "{sent}_{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # set() is not supported in script export + sents_seen: Dict[str, Optional[Tensor]] = {} + + # For every finished beam item + for i in range(bbsz_idx.size()[0]): + idx = bbsz_idx[i] + score = eos_scores[i] + # sentence index in the current (possibly reduced) batch + unfin_idx = idx // beam_size + # sentence index in the original (unreduced) batch + sent = unfin_idx + cum_unfin[unfin_idx] + # Cannot create dict for key type '(int, int)' in torchscript. + # The workaround is to cast int to string + seen = str(sent.item()) + "_" + str(unfin_idx.item()) + if seen not in sents_seen: + sents_seen[seen] = None + + if self.match_source_len and step > src_lengths[unfin_idx]: + score = torch.tensor(-math.inf).to(score) + + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent].append( + { + "tokens": tokens_clone[i], + "score": score, + "attention": hypo_attn, # src_len x tgt_len + "alignment": torch.empty(0), + "positional_scores": pos_scores[i], + } + ) + + newly_finished: List[int] = [] + + for seen in sents_seen.keys(): + # check termination conditions for this sentence + sent: int = int(float(seen.split("_")[0])) + unfin_idx: int = int(float(seen.split("_")[1])) + + if not finished[sent] and self.is_finished( + step, unfin_idx, max_len, len(finalized[sent]), beam_size + ): + finished[sent] = True + newly_finished.append(unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. + """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + def calculate_banned_tokens( + self, + tokens, + step: int, + gen_ngrams: List[Dict[str, List[int]]], + no_repeat_ngram_size: int, + bbsz_idx: int, + ): + tokens_list: List[int] = tokens[ + bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1 + ].tolist() + # before decoding the next token, prevent decoding of ngrams that have already appeared + ngram_index = ",".join([str(x) for x in tokens_list]) + return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], [])) + + def transpose_list(self, l: List[List[int]]): + # GeneratorExp aren't supported in TS so ignoring the lint + min_len = min([len(x) for x in l]) # noqa + l2 = [[row[i] for row in l] for i in range(min_len)] + return l2 + + def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): + # for each beam and batch sentence, generate a list of previous ngrams + gen_ngrams: List[Dict[str, List[int]]] = [ + torch.jit.annotate(Dict[str, List[int]], {}) + for bbsz_idx in range(bsz * beam_size) + ] + cpu_tokens = tokens.cpu() + for bbsz_idx in range(bsz * beam_size): + gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist() + for ngram in self.transpose_list( + [gen_tokens[i:] for i in range(self.no_repeat_ngram_size)] + ): + key = ",".join([str(x) for x in ngram[:-1]]) + gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get( + key, torch.jit.annotate(List[int], []) + ) + [ngram[-1]] + + if step + 2 - self.no_repeat_ngram_size >= 0: + # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + banned_tokens = [ + self.calculate_banned_tokens( + tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx + ) + for bbsz_idx in range(bsz * beam_size) + ] + else: + banned_tokens = [ + torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size) + ] + for bbsz_idx in range(bsz * beam_size): + lprobs[bbsz_idx][ + torch.tensor(banned_tokens[bbsz_idx]).long() + ] = torch.tensor(-math.inf).to(lprobs) + return lprobs + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + self.has_incremental: bool = False + if all( + hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + for m in models + ): + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, "encoder") + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min([m.max_decoder_positions() for m in self.models]) + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + return [model.encoder.forward_torchscript(net_input) for model in self.models] + + @torch.jit.export + def forward_decoder( + self, + tokens, + encoder_outs: List[EncoderOut], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, + ): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: Optional[EncoderOut] = None + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) + else: + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else decoder_out[1], + ) + + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log( + self.models_size + ) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[EncoderOut] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) + ) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( + incremental_states[i], new_order + ) + + +class SequenceGeneratorWithAlignment(SequenceGenerator): + def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs): + """Generates translations of a given source sentence. + + Produces alignments following "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + left_pad_target (bool, optional): Whether or not the + hypothesis should be left padded or not when they are + teacher forced for generating alignments. + """ + super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) + self.left_pad_target = left_pad_target + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + finalized = super()._generate(sample, **kwargs) + + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + beam_size = self.beam_size + ( + src_tokens, + src_lengths, + prev_output_tokens, + tgt_tokens, + ) = self._prepare_batch_for_alignment(sample, finalized) + if any(getattr(m, "full_context_alignment", False) for m in self.model.models): + attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) + else: + attn = [ + finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0) + for i in range(bsz * beam_size) + ] + + if src_tokens.device != "cpu": + src_tokens = src_tokens.to("cpu") + tgt_tokens = tgt_tokens.to("cpu") + attn = [i.to("cpu") for i in attn] + + # Process the attn matrix to extract hard alignments. + for i in range(bsz * beam_size): + alignment = utils.extract_hard_alignment( + attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos + ) + finalized[i // beam_size][i % beam_size]["alignment"] = alignment + return finalized + + def _prepare_batch_for_alignment(self, sample, hypothesis): + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + src_tokens = ( + src_tokens[:, None, :] + .expand(-1, self.beam_size, -1) + .contiguous() + .view(bsz * self.beam_size, -1) + ) + src_lengths = sample["net_input"]["src_lengths"] + src_lengths = ( + src_lengths[:, None] + .expand(-1, self.beam_size) + .contiguous() + .view(bsz * self.beam_size) + ) + prev_output_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=True, + ) + tgt_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=False, + ) + return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +class EnsembleModelWithAlignment(EnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + def forward_align(self, src_tokens, src_lengths, prev_output_tokens): + avg_attn = None + for model in self.models: + decoder_out = model(src_tokens, src_lengths, prev_output_tokens) + attn = decoder_out[1]["attn"][0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(self.models) > 1: + avg_attn.div_(len(self.models)) + return avg_attn diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py new file mode 100644 index 0000000..411d4df --- /dev/null +++ b/fairseq/sequence_scorer.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +import torch +from fairseq import utils + + +class SequenceScorer(object): + """Scores the target for a given source sentence.""" + + def __init__( + self, + tgt_dict, + softmax_batch=None, + compute_alignment=False, + eos=None, + symbols_to_strip_from_output=None, + ): + self.pad = tgt_dict.pad() + self.eos = tgt_dict.eos() if eos is None else eos + self.softmax_batch = softmax_batch or sys.maxsize + assert self.softmax_batch > 0 + self.compute_alignment = compute_alignment + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + """Score a batch of translations.""" + net_input = sample["net_input"] + + def batch_for_softmax(dec_out, target): + # assumes decoder_out[0] is the only thing needed (may not be correct for future models!) + first, rest = dec_out[0], dec_out[1:] + bsz, tsz, dim = first.shape + if bsz * tsz < self.softmax_batch: + yield dec_out, target, True + else: + flat = first.contiguous().view(1, -1, dim) + flat_tgt = target.contiguous().view(flat.shape[:-1]) + s = 0 + while s < flat.size(1): + e = s + self.softmax_batch + yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False + s = e + + def gather_target_probs(probs, target): + probs = probs.gather( + dim=2, + index=target.unsqueeze(-1), + ) + return probs + + orig_target = sample["target"] + + # compute scores for each model in the ensemble + avg_probs = None + avg_attn = None + for model in models: + model.eval() + decoder_out = model(**net_input) + attn = decoder_out[1] if len(decoder_out) > 1 else None + if type(attn) is dict: + attn = attn.get("attn", None) + + batched = batch_for_softmax(decoder_out, orig_target) + probs, idx = None, 0 + for bd, tgt, is_single in batched: + sample["target"] = tgt + curr_prob = model.get_normalized_probs( + bd, log_probs=len(models) == 1, sample=sample + ).data + if is_single: + probs = gather_target_probs(curr_prob, orig_target) + else: + if probs is None: + probs = curr_prob.new(orig_target.numel()) + step = curr_prob.size(0) * curr_prob.size(1) + end = step + idx + tgt_probs = gather_target_probs( + curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt + ) + probs[idx:end] = tgt_probs.view(-1) + idx = end + sample["target"] = orig_target + + probs = probs.view(sample["target"].shape) + + if avg_probs is None: + avg_probs = probs + else: + avg_probs.add_(probs) + if attn is not None: + if torch.is_tensor(attn): + attn = attn.data + else: + attn = attn[0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(models) > 1: + avg_probs.div_(len(models)) + avg_probs.log_() + if avg_attn is not None: + avg_attn.div_(len(models)) + + bsz = avg_probs.size(0) + hypos = [] + start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz + for i in range(bsz): + # remove padding from ref + ref = ( + utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad) + if sample["target"] is not None + else None + ) + tgt_len = ref.numel() + avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len] + score_i = avg_probs_i.sum() / tgt_len + if avg_attn is not None: + avg_attn_i = avg_attn[i] + if self.compute_alignment: + alignment = utils.extract_hard_alignment( + avg_attn_i, + sample["net_input"]["src_tokens"][i], + sample["target"][i], + self.pad, + self.eos, + ) + else: + alignment = None + else: + avg_attn_i = alignment = None + hypos.append( + [ + { + "tokens": ref, + "score": score_i, + "attention": avg_attn_i, + "alignment": alignment, + "positional_scores": avg_probs_i, + } + ] + ) + return hypos diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py new file mode 100644 index 0000000..41f461f --- /dev/null +++ b/fairseq/tasks/__init__.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +from fairseq.dataclass import FairseqDataclass +from omegaconf import DictConfig + +from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa + + +# register dataclass +TASK_DATACLASS_REGISTRY = {} +TASK_REGISTRY = {} +TASK_CLASS_NAMES = set() + + +def setup_task(cfg: DictConfig, **kwargs): + if isinstance(cfg, DictConfig): + return TASK_REGISTRY[cfg._name].setup_task(cfg, **kwargs) + return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs) + + +def register_task(name, dataclass=None): + """ + New tasks can be added to fairseq with the + :func:`~fairseq.tasks.register_task` function decorator. + + For example:: + + @register_task('classification') + class ClassificationTask(FairseqTask): + (...) + + .. note:: + + All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` + interface. + + Args: + name (str): the name of the task + """ + + def register_task_cls(cls): + if name in TASK_REGISTRY: + raise ValueError("Cannot register duplicate task ({})".format(name)) + if not issubclass(cls, FairseqTask): + raise ValueError( + "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__) + ) + if cls.__name__ in TASK_CLASS_NAMES: + raise ValueError( + "Cannot register task with duplicate class name ({})".format( + cls.__name__ + ) + ) + TASK_REGISTRY[name] = cls + TASK_CLASS_NAMES.add(cls.__name__) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if dataclass is not None: + TASK_DATACLASS_REGISTRY[name] = dataclass + + return cls + + return register_task_cls + + +def get_task(name): + return TASK_REGISTRY[name] + + +# automatically import any Python files in the tasks/ directory +tasks_dir = os.path.dirname(__file__) +for file in os.listdir(tasks_dir): + path = os.path.join(tasks_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + task_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.tasks." + task_name) + + # expose `task_parser` for sphinx + if task_name in TASK_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_task = parser.add_argument_group("Task name") + # fmt: off + group_task.add_argument('--task', metavar=task_name, + help='Enable this task with: ``--task=' + task_name + '``') + # fmt: on + group_args = parser.add_argument_group("Additional command-line arguments") + TASK_REGISTRY[task_name].add_args(group_args) + globals()[task_name + "_parser"] = parser diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py new file mode 100644 index 0000000..06b4abf --- /dev/null +++ b/fairseq/tasks/audio_pretraining.py @@ -0,0 +1,335 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import editdistance +import os +import sys +import numpy as np +import logging +import pdb +import torch + +from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset, ResamplingDataset, ConcatDataset, encoders +from fairseq.data.data_utils import post_process + +from . import LegacyFairseqTask, register_task +from .. import utils +from ..logging import metrics + + +logger = logging.getLogger(__name__) + +class LabelEncoder(object): + def __init__(self, dictionary): + self.dictionary = dictionary + + def __call__(self, label): + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False + ) + +class IDEncoder(object): + def __call__(self, ids): + ids = list(map(int, ids)) + idx = torch.IntTensor(ids) + return idx + + +@register_task("audio_pretraining") +class AudioPretrainingTask(LegacyFairseqTask): + """""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", help="path to data directory") + parser.add_argument("--train-path", default="", type=str) + parser.add_argument( + "--sample-rate", + default=16000, + type=int, + help="target sample rate. audio files will be up/down sampled to this rate", + ) + parser.add_argument( + "--normalize", + action="store_true", + help="if set, normalizes input to have 0 mean and unit variance", + ) + parser.add_argument( + "--max-sample-size", + default=None, + type=int, + help="max sample size to crop to for batching. default = min sample length", + ) + parser.add_argument( + "--min-sample-size", + default=None, + type=int, + help="min sample size to crop to for batching. default = same as --max-sample-size", + ) + + parser.add_argument( + "--enable-padding", + action="store_true", + help="pad shorter samples instead of cropping", + ) + + parser.add_argument( + "--labels", + type=str, + default=None, + help="extension of the label file to load, if any", + ) + + parser.add_argument( + "--multilang-sampling-alpha", + type=float, + default=0.5, + help="smoothing alpha for sample rations across multiple datasets", + ) + + parser.add_argument( + "--dict-path", + type=str, + default=None, + ) + # Options for reporting WER metrics during validation. Only applicable to + # Seq2Seq models during fine-tuning + parser.add_argument( + "--eval-wer", + action="store_true", + help="compute WER for Seq2Seq models", + ) + parser.add_argument( + "--eval-wer-remove-bpe", + default="letter", + help="remove BPE tokens before scoring (can be sentencepiece, letter, and more)", + ) + + def __init__(self, args, source_dictionary=None, target_dictionary=None): + super().__init__(args) + self._target_dictionary = target_dictionary + self._source_dictionary = source_dictionary + self.is_ctc = args.criterion == "ctc" + if getattr(self.args, "eval_wer", False): + assert args.labels is not None, "eval_wer can only be set during fine-tuning" + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (omegaconf.DictConfig): parsed command-line arguments + """ + + if args.labels: + if args.dict_path is not None: + target_dictionary = Dictionary(args.dict_path) + else: + dict_path = os.path.join(args.data, f"dict.{args.labels}.txt") + target_dictionary = Dictionary.load(dict_path) + else: + target_dictionary = None + + return cls(args, target_dictionary=target_dictionary) + + def _get_sample_prob(self, dataset_lens): + """ + Get smoothed sampling porbability by languages. This helps low resource + languages by upsampling them. + """ + prob = dataset_lens / dataset_lens.sum() + smoothed_prob = prob ** self.args.multilang_sampling_alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + return smoothed_prob + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + manifest_list = split.split(",") + datasets = [] + datasets_lengths = [] + for f in manifest_list: + manifest = os.path.join(self.args.data, "{}.tsv".format(f)) + dataset = FileAudioDataset( + manifest, + sample_rate=self.args.sample_rate, + max_sample_size=self.args.max_sample_size, + min_sample_size=self.args.max_sample_size, + min_length=self.args.min_sample_size, + pad=self.args.labels is not None or self.args.enable_padding, + normalize=self.args.normalize, + ) + + if self.args.labels: + if self.args.dict_path is not None: + dict_path = self.args.dict_path + else: + dict_path = os.path.join(self.args.data, f"dict.{self.args.labels}.txt") + self._target_dictionary = Dictionary(dict_path) + label_path = os.path.join(self.args.data, "{}.id".format(f)) + labels = [] + count = 0 + itr = 0 + with open(label_path, "r") as f: + for line in f: + count += 1 + if itr < len(dataset.skipped) and count == dataset.skipped[itr]: + itr += 1 + continue + line = line.strip().split() + labels.append(line) + + process_label = IDEncoder() + assert len(labels) == len(dataset) + + dataset = AddTargetDataset( + dataset, + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + add_to_input=False, + ) + datasets.append(dataset) + datasets_lengths.append(sum(dataset.sizes) / self.args.sample_rate / 3600) + + if len(manifest_list) == 1: + self.datasets[split] = dataset + else: + languages = [manifest.split('/')[0] for manifest in manifest_list] + datasets_lengths = np.array(datasets_lengths) + sample_probs = self._get_sample_prob(datasets_lengths) + for id, lang in enumerate(languages): + logger.info( + "Sample probability by language: {} : {:.4f}".format(lang, sample_probs[id]) + ) + size_ratio = (sample_probs * datasets_lengths.sum()) / datasets_lengths + for id, lang in enumerate(languages): + logger.info( + "Up/Down Sampling ratio by language: {} : {:.2f}".format(lang, size_ratio[id]) + ) + + resampled_lang_datasets = [ + ResamplingDataset( + datasets[i], + size_ratio=size_ratio[i], + seed=self.args.seed, + epoch=epoch, + replace=size_ratio[i] >= 1.0, + ) + for i, d in enumerate(datasets) + ] + self.datasets[split] = ConcatDataset(resampled_lang_datasets) + + @property + def source_dictionary(self): + return self._source_dictionary + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self._target_dictionary + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size( + self, + indices, + dataset, + max_positions=None, + ignore_invalid_inputs=False, + ): + # we do not need to filter by size in this task as dataloaders take care of this + return indices + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + if getattr(self.args, "eval_wer", False) and not self.is_ctc: + metrics = self._inference_with_wer(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + return loss, sample_size, logging_output + + def build_model(self, args): + model = super().build_model(args) + + if getattr(args, 'eval_wer', False) and not self.is_ctc: + self.sequence_generator = self.build_generator([model], args, ) + self.tokenizer = encoders.build_tokenizer(args) + return model + + def _inference_with_wer(self, generator, sample, model): + def decode(toks, escape_unk=True): + s = self.target_dictionary.string( + toks.int().cpu(), + self.args.eval_wer_remove_bpe, + escape_unk=escape_unk, + extra_symbols_to_ignore={generator.eos}, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp = decode(gen_out[i][0]["tokens"]) + ref = decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + escape_unk=True, + ) + hyp = post_process(hyp, self.args.eval_wer_remove_bpe).strip("_") + ref = post_process(ref, self.args.eval_wer_remove_bpe).strip("_") + num_char_errors += editdistance.eval(hyp, ref) + num_chars += len(ref) + hyp_words = hyp.split("_") + ref_words = ref.split("_") + num_word_errors += editdistance.eval(hyp_words, ref_words) + num_words += len(ref_words) + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + } + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + zero = torch.scalar_tensor(0.) + num_char_errors = sum(log.get("_num_char_errors", zero) for log in logging_outputs) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum(log.get("_num_word_errors", zero) for log in logging_outputs) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + metrics.log_scalar("_num_char_errors", num_char_errors) + metrics.log_scalar("_num_chars", num_chars) + metrics.log_scalar("_num_word_errors", num_word_errors) + metrics.log_scalar("_num_words", num_words) + if num_words > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum * 100.0 / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 else float("nan") + ) + metrics.log_derived( + "wer", + lambda meters: meters["_num_word_errors"].sum * 100.0 / meters["_num_words"].sum + if meters["_num_words"].sum > 0 else float("nan") + ) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py new file mode 100644 index 0000000..3cdb64c --- /dev/null +++ b/fairseq/tasks/fairseq_task.py @@ -0,0 +1,567 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import warnings +from argparse import Namespace + +import torch +from fairseq import metrics, search, tokenizer, utils +from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators +from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import DictConfig + + +logger = logging.getLogger(__name__) + + +class FairseqTask(object): + """ + Tasks store dictionaries and provide helpers for loading/iterating over + Datasets, initializing the Model/Criterion and calculating the loss. + """ + + @classmethod + def add_args(cls, parser): + """Add task-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @staticmethod + def logging_outputs_can_be_summed(criterion) -> bool: + """ + Whether the logging outputs returned by `train_step` and `valid_step` can + be summed across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. + """ + return criterion.logging_outputs_can_be_summed() + + def __init__(self, cfg: DictConfig, **kwargs): + self.cfg = cfg + self.datasets = {} + self.dataset_to_epoch_iter = {} + + @classmethod + def load_dictionary(cls, filename): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + return Dictionary.load(filename) + + @classmethod + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): + """Build the dictionary + + Args: + filenames (list): list of filenames + workers (int): number of concurrent workers + threshold (int): defines the minimum word count + nwords (int): defines the total number of words in the final dictionary, + including special symbols + padding_factor (int): can be used to pad the dictionary size to be a + multiple of 8, which is important on some hardware (e.g., Nvidia + Tensor Cores). + """ + d = Dictionary() + for filename in filenames: + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @classmethod + def setup_task(cls, cfg: DictConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (omegaconf.DictConfig): parsed command-line arguments + """ + return cls(cfg, **kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.cfg, "data", "") + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + raise NotImplementedError + + def dataset(self, split): + """ + Return a loaded dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + + Returns: + a :class:`~fairseq.data.FairseqDataset` corresponding to *split* + """ + from fairseq.data import FairseqDataset + + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + if not isinstance(self.datasets[split], FairseqDataset): + raise TypeError("Datasets are expected to be of type FairseqDataset") + return self.datasets[split] + + def filter_indices_by_size( + self, indices, dataset, max_positions=None, ignore_invalid_inputs=False + ): + """ + Filter examples that are too large + + Args: + indices (np.array): original array of sample indices + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + Returns: + np.array: array of filtered sample indices + """ + indices, ignored = dataset.filter_indices_by_size(indices, max_positions) + if len(ignored) > 0: + if not ignore_invalid_inputs: + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) + logger.warning( + ( + "{} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) + return indices + + def can_reuse_epoch_itr(self, dataset): + # We can reuse the epoch iterator across epochs as long as the dataset + # hasn't disabled it. We default to ``False`` here, although in practice + # this will be ``True`` for most datasets that inherit from + # ``FairseqDataset`` due to the base implementation there. + return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False) + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + ): + """ + Get an iterator that yields batches of data from the given dataset. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + data_buffer_size (int, optional): number of batches to + preload (default: 0). + disable_iterator_cache (bool, optional): don't cache the + EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) + (default: False). + Returns: + ~fairseq.iterators.EpochBatchIterator: a batched iterator over the + given dataset split + """ + can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr( + dataset + ) + if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: + logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch)) + return self.dataset_to_epoch_iter[dataset] + + assert isinstance(dataset, FairseqDataset) + + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + # get indices ordered by example size + with data_utils.numpy_seed(seed): + indices = dataset.ordered_indices() + + # filter examples that are too large + if max_positions is not None: + indices = self.filter_indices_by_size( + indices, dataset, max_positions, ignore_invalid_inputs + ) + + # create mini-batches with given size constraints + batch_sampler = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + # return a reusable, sharded iterator + epoch_iter = iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + buffer_size=data_buffer_size, + ) + + if can_reuse_epoch_itr: + self.dataset_to_epoch_iter[dataset] = epoch_iter + + return epoch_iter + + def build_model(self, cfg: DictConfig): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + cfg (omegaconf.DictConfig): configuration object + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + + model = models.build_model(cfg, self) + if getattr(cfg, "tpu", False): + model.prepare_for_tpu_() + model = quantization_utils.quantize_model_scalar(model, cfg) + return model + + def build_criterion(self, cfg: DictConfig): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + cfg (omegaconf.DictConfig): configration object + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(cfg, self) + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None + ): + if getattr(args, "score_reference", False): + from fairseq.sequence_scorer import SequenceScorer + + return SequenceScorer( + self.target_dictionary, + compute_alignment=getattr(args, "print_alignment", False), + ) + + from fairseq.sequence_generator import ( + SequenceGenerator, + SequenceGeneratorWithAlignment, + ) + + # Choose search strategy. Defaults to Beam Search. + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) + diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) + match_source_len = getattr(args, "match_source_len", False) + diversity_rate = getattr(args, "diversity_rate", -1) + constrained = getattr(args, "constraints", False) + prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) + if ( + sum( + int(cond) + for cond in [ + sampling, + diverse_beam_groups > 0, + match_source_len, + diversity_rate > 0, + ] + ) + > 1 + ): + raise ValueError("Provided Search parameters are mutually exclusive.") + assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" + assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" + + if sampling: + search_strategy = search.Sampling( + self.target_dictionary, sampling_topk, sampling_topp + ) + elif diverse_beam_groups > 0: + search_strategy = search.DiverseBeamSearch( + self.target_dictionary, diverse_beam_groups, diverse_beam_strength + ) + elif match_source_len: + # this is useful for tagging applications where the output + # length should match the input length, so we hardcode the + # length constraints for simplicity + search_strategy = search.LengthConstrainedBeamSearch( + self.target_dictionary, + min_len_a=1, + min_len_b=0, + max_len_a=1, + max_len_b=0, + ) + elif diversity_rate > -1: + search_strategy = search.DiverseSiblingsSearch( + self.target_dictionary, diversity_rate + ) + elif constrained: + search_strategy = search.LexicallyConstrainedBeamSearch( + self.target_dictionary, args.constraints + ) + elif prefix_allowed_tokens_fn: + search_strategy = search.PrefixConstrainedBeamSearch( + self.target_dictionary, prefix_allowed_tokens_fn + ) + else: + search_strategy = search.BeamSearch(self.target_dictionary) + + if seq_gen_cls is None: + if getattr(args, "print_alignment", False): + seq_gen_cls = SequenceGeneratorWithAlignment + else: + seq_gen_cls = SequenceGenerator + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + return seq_gen_cls( + models, + self.target_dictionary, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search_strategy, + **extra_gen_cls_kwargs, + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + """ + Do forward and backward, and return the loss as computed by *criterion* + for the given *model* and *sample*. + + Args: + sample (dict): the mini-batch. The format is defined by the + :class:`~fairseq.data.FairseqDataset`. + model (~fairseq.models.BaseFairseqModel): the model + criterion (~fairseq.criterions.FairseqCriterion): the criterion + optimizer (~fairseq.optim.FairseqOptimizer): the optimizer + update_num (int): the current update + ignore_grad (bool): multiply loss by 0 if this is set to True + + Returns: + tuple: + - the loss + - the sample size, which is used as the denominator for the + gradient + - logging outputs to display while training + """ + model.train() + model.set_num_updates(update_num) + with torch.autograd.profiler.record_function("forward"): + loss, sample_size, logging_output = criterion(model, sample) + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, constraints=constraints + ) + + def begin_epoch(self, epoch, model): + """Hook function called before the start of each epoch.""" + pass + + def begin_valid_epoch(self, epoch, model): + """Hook function called before the start of each validation epoch.""" + pass + + def aggregate_logging_outputs(self, logging_outputs, criterion): + """[deprecated] Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "The aggregate_logging_outputs API is deprecated. " + "Please use the reduce_metrics API instead." + ) + with metrics.aggregate() as agg: + self.reduce_metrics(logging_outputs, criterion) + return agg.get_smoothed_values() + + def reduce_metrics(self, logging_outputs, criterion): + """Aggregate logging outputs from data parallel training.""" + # backward compatibility for tasks that override aggregate_logging_outputs + base_func = FairseqTask.aggregate_logging_outputs + self_func = getattr(self, "aggregate_logging_outputs").__func__ + if self_func is not base_func: + utils.deprecation_warning( + "Tasks should implement the reduce_metrics API. " + "Falling back to deprecated aggregate_logging_outputs API." + ) + agg_logging_outputs = self.aggregate_logging_outputs( + logging_outputs, criterion + ) + for k, v in agg_logging_outputs.items(): + metrics.log_scalar(k, v) + return + + if not any("ntokens" in log for log in logging_outputs): + warnings.warn( + "ntokens not found in Criterion logging outputs, cannot log wpb or wps" + ) + else: + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + metrics.log_scalar("wpb", ntokens, priority=180, round=1) + metrics.log_speed("wps", ntokens, priority=90, round=1) + + if not any("nsentences" in log for log in logging_outputs): + warnings.warn( + "nsentences not found in Criterion logging outputs, cannot log bsz" + ) + else: + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + metrics.log_scalar("bsz", nsentences, priority=190, round=1) + + criterion.__class__.reduce_metrics(logging_outputs) + + def max_positions(self): + """Return the max input length allowed by the task.""" + return None + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + raise NotImplementedError + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + raise NotImplementedError + + def build_tokenizer(self, args): + """Build the pre-tokenizer for this task.""" + return encoders.build_tokenizer(args) + + def build_bpe(self, args): + """Build the tokenizer for this task.""" + return encoders.build_bpe(args) + + +class LegacyFairseqTask(FairseqTask): + def __init__(self, args: Namespace): + self.args = args + self.datasets = {} + self.dataset_to_epoch_iter = {} + + @classmethod + def setup_task(cls, args: Namespace, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + return cls(args, **kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.args, "data", "") + + def build_model(self, args: Namespace): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + + model = models.build_model(args, self) + if getattr(args, "tpu", False): + model.prepare_for_tpu_() + model = quantization_utils.quantize_model_scalar(model, args) + return model + + def build_criterion(self, args: Namespace): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(args, self) diff --git a/fairseq/token_generation_constraints.py b/fairseq/token_generation_constraints.py new file mode 100644 index 0000000..e708dc5 --- /dev/null +++ b/fairseq/token_generation_constraints.py @@ -0,0 +1,506 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Implements tracking of constraints for a beam item. + +A list of constraints is given as a list of one or more token +sequences, each of length at least one token. For example, for an input sentence + +> Die maschinelle Übersetzung ist schwer zu kontrollieren. + +We could have the constraints: +* to influence +* hard + +There are two implementations: +* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. +* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. + +The difference is that in the first, the constraints are assumed to be +in order; the algorithm will permit zero or more tokens between them. +In the second, the constraints are not ordered, so many orderings will +be explored. + +The same sequence can be present any number of times, and will appear +that many times in the output. +""" + +from collections import Counter +from typing import List, Optional, Set, Tuple + +import torch + + +class ConstraintState: + def __init__(self): + pass + + +def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: + """Takes a list of list of constraints in tensor form (a list of + tensor constraints for each sentence) and transforms it into a + packed Tensor. For example, here is a batch of size 3 with 3, 0, + and 1 constraints: + + [ [ [3 1 2], [3], [4 5 6 7], ] + [], + [ [1 8 9 10 1 4 11 12], ] + ] + + Its corresponding packed structure is: + + [ [ 3 3 1 2 0 3 0 4 5 6 7 0], + [ 0 0 0 0 0 0 0 0 0 0 0 0], + [ 1 1 8 9 10 1 4 11 12 0 0 0] ] + + The packed tensor has shape (batch size, maxlen), where + maxlen is defined below. Each row contains concatenated + constraint tokens for that sentence, with 0 appended after + each constraint. The first item in each row is the number + of constraints for that sentence. So maxlen is the maximum + of + + (number of constraints) + (sum length of constraints) + 1. + + across all sentences in the batch. + """ + # The maximum word length of concatenated constraints for any sentence + max_constraints_len = 1 + for sentence_constraints in batch_constraints: + if len(sentence_constraints): + # number of constraints, plus sum of constrain lens, plus a zero after each + constraints_len = ( + 1 + + sum([c.size(0) for c in sentence_constraints]) + + len(sentence_constraints) + ) + max_constraints_len = max(max_constraints_len, constraints_len) + + batch_size = len(batch_constraints) + constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() + for i, sentence_constraints in enumerate(batch_constraints): + constraints_tensor[i, 0] = len(sentence_constraints) + offset = 1 + for j, constraint in enumerate(sentence_constraints): + this_len = constraint.size(0) + constraints_tensor[i, offset : offset + this_len] = constraint + offset += this_len + 1 + + return constraints_tensor.long() + + +def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Transforms *one row* of a packed constraint tensor (e.g., for one + sentence in the batch) into a list of constraint tensors. + """ + constraint_list = [] + num_constraints = constraint_tensor[0] + constraints = constraint_tensor.tolist() + offset = 1 + for i in range(num_constraints): + where = constraints.index(0, offset) + constraint_list.append(constraint_tensor[offset:where]) + offset = where + 1 + + return constraint_list + + +class ConstraintNode: + """ + Represents a node in a trie managing unordered constraints. + """ + + def __init__(self, token: int = None, parent=None): + # The token associate with this node (None for the root) + self.token = int(token) if token is not None else None + # The parent (None at the root) + self.parent = parent + # Whether this node is a completed constraint + self.terminal = 0 + # List of child nodes + self.children = {} + + # The cumulative number of constraints from this point in the + # trie forward + self.num_constraints = 0 + + @property + def id(self): + return self.token + + def __str__(self): + term = self.terminal != 0 + return f"[{self.token}].{term}#{self.num_constraints}" + + def __getitem__(self, key: int): + return self.children.get(key, None) + + def next_tokens(self) -> Set[int]: + """The set of child labels.""" + return set(self.children.keys()) + + @staticmethod + def create(constraints: List[List[int]]): + root = ConstraintNode() + for sequence in constraints: + root.add_sequence(sequence) + + return root + + @staticmethod + def print_graph(node: "ConstraintNode"): + if len(node.children) == 0: + return str(node) + else: + s = f"({node}" + for child in node.children.values(): + s += " " + ConstraintNode.print_graph(child) + s += ")" + return s + + def token_counts(self) -> Counter: + """Returns a counter of the number of times each token is used + in a constraint. + """ + token_counts = Counter() + kids = list(self.children.values()) + while len(kids) > 0: + kid = kids.pop() + token_counts[kid.id] += kid.num_constraints + kids += list(kid.children.values()) + + return token_counts + + def tokens(self) -> Set[int]: + """Returns the set of tokens in constraints.""" + return set(self.token_counts().keys()) + + def add_sequence(self, sequence: List[int]): + """Adds a constraint, represented as a list of integers, to + the trie.""" + assert len(sequence) > 0 + + token = int(sequence[0]) + if token not in self.children: + self.children[token] = ConstraintNode(token, parent=self) + + node = self.children[token] + if len(sequence) == 1: + node.terminal += 1 + node.num_constraints += 1 + parent = node.parent + while parent is not None: + parent.num_constraints += 1 + parent = parent.parent + else: + node.add_sequence(sequence[1:]) + + +class UnorderedConstraintState(ConstraintState): + """ + Records progress through the set of constraints for each item in the beam + using a trie. + """ + + def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None): + self.node = node + + if copy_from is None: + # The root node + self.root = node + # The set of states in the graph that have been completed + self.completed = Counter() + # The... + self.generated = Counter() + # The list of tokens we need to generate + self.needed_tokens = self.root.tokens() + else: + self.completed = Counter(copy_from.completed) + self.generated = Counter(copy_from.generated) + self.root = copy_from.root + + # Mark the node as generated + if self.node != self.root: + self.generated[node] += 1 + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + constraint_trie_root = ConstraintNode.create(constraint_list) + return UnorderedConstraintState(constraint_trie_root) + + def __str__(self): + gen_str = ",".join([str(node) for node in self.generated]) + return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}" + + def __copy__(self): + copied_state = UnorderedConstraintState(self.node, copy_from=self) + return copied_state + + def copy(self): + return self.__copy__() + + @property + def name(self): + if self.node.id is None: + return "ROOT" + else: + return str(self.node.id) + + @property + def is_root(self): + return self.node == self.root + + @property + def bank(self): + return sum(self.generated.values()) + + @property + def num_completed(self): + """The number of constraints (not constraint tokens) that are completed. + In addition to the already-completed states, we need to account for the + current state, which might get marked as completed when another token + is generated. + """ + in_final = self.node.terminal and self.completed[self.node] < self.node.terminal + return sum(self.completed.values()) + in_final + + @property + def finished(self): + return self.root.num_constraints - self.num_completed == 0 + + @property + def token_counts(self): + return self.root.token_counts() + + @property + def tokens(self): + return self.root.tokens() + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + if self.node != self.root: + return self.root.next_tokens().union(self.node.next_tokens()) + else: + return self.root.next_tokens() + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + + next_state = None + child = self.node[token] + if child is not None and self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + + def rewind(): + """If we're mid-trie and an "illegal" token is chosen next, we need + to reset our state to the root state. However, along the way, we need + to check whether a prefix of the current trie state represents a state + we could mark as completed. + """ + node = self.node + while node != self.root: + if node.terminal and self.completed[node] < node.terminal: + next_state.completed[node] += 1 + return + + next_state.generated[node] -= 1 + node = node.parent + + # Fall off the graph, check the root + if next_state is None and token in self.root.next_tokens(): + child = self.root[token] + # We can only traverse this edge if it's not saturated + if self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + else: + next_state = UnorderedConstraintState(self.root, copy_from=self) + + # Rewind + rewind() + + elif next_state is None: + next_state = UnorderedConstraintState(self.root, copy_from=self) + # Rewind + rewind() + + return next_state + + +class ConstraintSequence: + def __init__(self, sequences: List[List[int]]): + """Represents a set of possibly multitoken constraints by + concatenating them and internally recording the end points. + """ + self.sequences = [] + self.endpoints = [] + self.num_tokens = 0 + self.tokens = set() + for sequence in sequences: + for token in sequence: + self.tokens.add(token) + self.num_tokens += len(sequence) + self.endpoints += [False for x in range(len(sequence) - 1)] + [True] + self.sequences += sequence + + def __getitem__(self, key: int): + return self.sequences[key] + + def __len__(self): + return len(self.sequences) + + def __str__(self): + return str(self.sequences) + + +class OrderedConstraintState(ConstraintState): + """ + Records progress through the set of linear nonbranching constraints with gaps. + """ + + def __init__(self, sequence: ConstraintSequence, state: int = -1): + self.sequence = sequence + self.state = state + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + return OrderedConstraintState(ConstraintSequence(constraint_list), -1) + + def __str__(self): + return f"{self.state}/{self.bank}x{self.num_completed}" + + def __copy__(self): + return OrderedConstraintState(self.sequence, self.state) + + def copy(self): + return self.__copy__() + + @property + def num_completed(self): + if self.state == -1: + return 0 + count = len( + list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1])) + ) + return count + + @property + def is_root(self): + return self.state == -1 + + @property + def name(self): + if self.state == -1: + return "ROOT" + else: + return str(self.sequence[self.state]) + + @property + def bank(self) -> int: + return self.state + 1 + + @property + def finished(self): + return self.state + 1 == len(self.sequence) + + @property + def token_counts(self): + return self.sequence.token_counts() + + @property + def tokens(self): + return self.sequence.tokens + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + tokens = set() + if self.state > 0: + tokens.add(self.sequence[0]) + if not self.finished: + tokens.add(self.sequence[self.state + 1]) + return tokens + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") + + if self.finished: + # Accept anything + next_state = self.copy() + + elif self.sequence[self.state + 1] == token: + # Advance to the next token + next_state = OrderedConstraintState(self.sequence, self.state + 1) + + elif self.sequence.endpoints[self.state]: + # Accept anything between constraints (*) + next_state = self.copy() + + elif token == self.sequence[0]: + # Start over having generated the first token + next_state = OrderedConstraintState(self.sequence, 0) + else: + # Start over from the root + next_state = OrderedConstraintState(self.sequence, -1) + + return next_state diff --git a/fairseq/tokenizer.py b/fairseq/tokenizer.py new file mode 100644 index 0000000..42131f7 --- /dev/null +++ b/fairseq/tokenizer.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + + +SPACE_NORMALIZER = re.compile(r"\s+") + + +def tokenize_line(line): + line = SPACE_NORMALIZER.sub(" ", line) + line = line.strip() + return line.split() diff --git a/fairseq/trainer.py b/fairseq/trainer.py new file mode 100644 index 0000000..a4d273c --- /dev/null +++ b/fairseq/trainer.py @@ -0,0 +1,1149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Train a network across multiple GPUs. +""" + +import contextlib +import logging +import sys +import time +from argparse import Namespace +from itertools import chain +from typing import Any, Dict, List + +import torch +from fairseq import checkpoint_utils, distributed_utils, models, optim, utils +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.file_io import PathManager +from fairseq.logging import meters, metrics +from fairseq.nan_detector import NanDetector +from fairseq.optim import lr_scheduler + + +logger = logging.getLogger(__name__) + + +class Trainer(object): + """Main class for data parallel training. + + This class supports synchronous distributed data parallel training, + where multiple workers each have a full model replica and gradients + are accumulated across workers before each update. We use + :class:`~torch.nn.parallel.DistributedDataParallel` to handle + communication of the gradients across workers. + """ + + def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): + + if isinstance(cfg, Namespace): + logger.warning( + "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf" + ) + cfg = convert_namespace_to_omegaconf(cfg) + + self.cfg = cfg + self.task = task + + # catalog shared parameters + shared_params = _catalog_shared_params(model) + self.tpu = cfg.common.tpu + self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu + if self.cuda: + self.device = torch.device("cuda") + elif self.tpu: + self.device = utils.get_tpu_device() + else: + self.device = torch.device("cpu") + + # copy model and criterion to current device/dtype + self._criterion = criterion + self._model = model + if self.tpu: + import torch_xla.core.xla_model as xm + + self._model = xm.send_cpu_data_to_device(self._model, self.device) + if cfg.common.fp16: + self._criterion = self._criterion.half() + self._model = self._model.half() + elif cfg.common.bf16: + self._criterion = self._criterion.to(dtype=torch.bfloat16) + self._model = self._model.to(dtype=torch.bfloat16) + if not cfg.distributed_training.pipeline_model_parallel: + self._criterion = self._criterion.to(device=self.device) + self._model = self._model.to(device=self.device) + self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel + self.last_device = None + if self.cuda and self.pipeline_model_parallel: + self.last_device = torch.device( + cfg.distributed_training.pipeline_devices[-1] + ) + + # check that shared parameters are preserved after device transfer + for shared_param in shared_params: + ref = _get_module_by_path(self._model, shared_param[0]) + for path in shared_param[1:]: + logger.info( + "detected shared parameter: {} <- {}".format(shared_param[0], path) + ) + _set_module_by_path(self._model, path, ref) + + self._dummy_batch = None # indicates we don't have a dummy batch at first + self._lr_scheduler = None + self._num_updates = 0 + self._num_xla_compiles = 0 # for TPUs + self._optim_history = None + self._optimizer = None + self._warn_once = set() + self._wrapped_criterion = None + self._wrapped_model = None + + # TODO(myleott): support tpu + if self.cuda and self.data_parallel_world_size > 1: + self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size) + else: + self._grad_norm_buf = None + + self.quantizer = quantizer + if self.quantizer is not None: + self.quantizer.set_trainer(self) + + # get detailed cuda environment + if self.cuda: + self.cuda_env = utils.CudaEnvironment() + if self.data_parallel_world_size > 1: + self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env) + else: + self.cuda_env_arr = [self.cuda_env] + if self.data_parallel_rank == 0: + utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr) + else: + self.cuda_env = None + self.cuda_env_arr = None + + metrics.log_start_time("wall", priority=790, round=0) + + self._start_time = time.time() + self._previous_training_time = 0 + self._cumulative_training_time = None + + def reinitialize(self): + """Reinitialize the Trainer, typically after model params change.""" + self._lr_scheduler = None + self._optimizer = None + self._wrapped_criterion = None + self._wrapped_model = None + + @property + def data_parallel_world_size(self): + return self.cfg.distributed_training.distributed_world_size + + @property + def data_parallel_process_group(self): + if self.tpu: + return ("tpu", None) + else: + return None + + @property + def data_parallel_rank(self): + return self.cfg.distributed_training.distributed_rank + + @property + def is_data_parallel_master(self): + return distributed_utils.is_master(self.cfg.distributed_training) + + @property + def criterion(self): + if self._wrapped_criterion is None: + if ( + utils.has_parameters(self._criterion) + and self.data_parallel_world_size > 1 + and not self.cfg.optimization.use_bmuf + and not self.tpu + ): + self._wrapped_criterion = models.DistributedFairseqModel( + self.cfg.distributed_training, + self._criterion, + process_group=self.data_parallel_process_group, + ) + else: + self._wrapped_criterion = self._criterion + return self._wrapped_criterion + + @property + def model(self): + if self._wrapped_model is None: + if ( + self.data_parallel_world_size > 1 + and not self.cfg.optimization.use_bmuf + and not self.tpu + ): + self._wrapped_model = models.DistributedFairseqModel( + self.cfg.distributed_training, + self._model, + process_group=self.data_parallel_process_group, + ) + else: + self._wrapped_model = self._model + return self._wrapped_model + + @property + def optimizer(self): + if self._optimizer is None: + self._build_optimizer() + return self._optimizer + + @property + def lr_scheduler(self): + if self._lr_scheduler is None: + self._build_optimizer() # this will initialize self._lr_scheduler + return self._lr_scheduler + + def _build_optimizer(self): + params = list( + filter( + lambda p: p.requires_grad, + chain(self.model.parameters(), self.criterion.parameters()), + ) + ) + + if self.cfg.common.fp16 or self.cfg.common.bf16: + if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: + logger.info( + "NOTE: your device does NOT support faster training with --fp16, " + "please switch to FP32 which is likely to be faster" + ) + if ( + self.cfg.common.memory_efficient_fp16 + or self.cfg.common.memory_efficient_bf16 + ): + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params + ) + else: + self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params) + else: + if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: + logger.info("NOTE: your device may support faster training with --fp16") + self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) + + if self.cfg.optimization.use_bmuf: + self._optimizer = optim.FairseqBMUF( + self.cfg.bmuf, + self._optimizer, + ) + + if self.cfg.distributed_training.zero_sharding == "os": + if ( + self.cfg.common.fp16 + and not self.cfg.common.memory_efficient_fp16 + and not self.cfg.common.memory_efficient_bf16 + ) and not self.cfg.common.fp16_no_flatten_grads: + raise ValueError( + "ZeRO is incomptabile with fp16 and flattened grads. " + "Please use --fp16-no-flatten-grads" + ) + else: + optim.shard_(self._optimizer, self.data_parallel_process_group) + + # We should initialize the learning rate scheduler immediately after + # building the optimizer, so that the initial learning rate is set. + self._lr_scheduler = lr_scheduler.build_lr_scheduler( + self.cfg.lr_scheduler, + self.optimizer, + ) + self._lr_scheduler.step_update(0) + + def consolidate_optimizer(self): + """For OSS, we need to consolidate the state dict.""" + if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): + self.optimizer.optimizer.consolidate_state_dict() + + def save_checkpoint(self, filename, extra_state): + """Save all training state in a checkpoint file.""" + if self.is_data_parallel_master: # only save one checkpoint + extra_state["metrics"] = metrics.state_dict() + extra_state["previous_training_time"] = self.cumulative_training_time() + checkpoint_utils.save_state( + filename, + self.cfg, + self.get_model().state_dict(), + self.get_criterion(), + self.optimizer, + self.lr_scheduler, + self.get_num_updates(), + self._optim_history, + extra_state, + ) + + def load_checkpoint( + self, + filename, + reset_optimizer=False, + reset_lr_scheduler=False, + optimizer_overrides=None, + reset_meters=False, + ): + """Load all training state from a checkpoint file.""" + extra_state, self._optim_history, last_optim_state = None, [], None + + bexists = PathManager.isfile(filename) + if bexists: + state = checkpoint_utils.load_checkpoint_to_cpu(filename) + # load model parameters + try: + self.get_model().load_state_dict( + state["model"], strict=True, model_cfg=self.cfg.model + ) + if utils.has_parameters(self.get_criterion()): + self.get_criterion().load_state_dict( + state["criterion"], strict=True + ) + except Exception: + raise Exception( + "Cannot load model parameters from checkpoint {}; " + "please ensure that the architectures match.".format(filename) + ) + + extra_state = state["extra_state"] + self._optim_history = state["optimizer_history"] + last_optim_state = state.get("last_optimizer_state", None) + + if last_optim_state is not None and not reset_optimizer: + # rebuild optimizer after loading model, since params may have changed + self._build_optimizer() + + # only reload optimizer and lr_scheduler if they match + last_optim = self._optim_history[-1] + assert ( + last_optim["criterion_name"] == self.get_criterion().__class__.__name__ + ), "Criterion does not match; please reset the optimizer (--reset-optimizer)." + assert ( + last_optim["optimizer_name"] == self.optimizer.__class__.__name__ + ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)." + + if not reset_lr_scheduler: + self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) + self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) + + self.set_num_updates(last_optim["num_updates"]) + + if extra_state is not None: + epoch = extra_state["train_iterator"]["epoch"] + logger.info( + "loaded checkpoint {} (epoch {} @ {} updates)".format( + filename, epoch, self.get_num_updates() + ) + ) + + if "previous_training_time" in extra_state: + self._previous_training_time = extra_state["previous_training_time"] + self._start_time = time.time() + + self.lr_step(epoch) + + if "metrics" in extra_state and not reset_meters: + metrics.load_state_dict(extra_state["metrics"]) + + # reset TimeMeters, since their start times don't make sense anymore + for meter in metrics.get_meters("default"): + if isinstance(meter, meters.TimeMeter): + meter.reset() + else: + logger.info("no existing checkpoint found {}".format(filename)) + + return extra_state + + def get_train_iterator( + self, + epoch, + combine=True, + load_dataset=True, + data_selector=None, + shard_batch_itr=True, + disable_iterator_cache=False, + ): + """Return an EpochBatchIterator over the training set for a given epoch.""" + if load_dataset: + logger.info("loading train data for epoch {}".format(epoch)) + self.task.load_dataset( + self.cfg.dataset.train_subset, + epoch=epoch, + combine=combine, + data_selector=data_selector, + ) + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.dataset(self.cfg.dataset.train_subset), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + self.task.max_positions(), + self.model.max_positions(), + self.cfg.dataset.max_tokens, + ), + ignore_invalid_inputs=True, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, + num_shards=self.data_parallel_world_size if shard_batch_itr else 1, + shard_id=self.data_parallel_rank if shard_batch_itr else 0, + num_workers=self.cfg.dataset.num_workers, + epoch=epoch, + data_buffer_size=self.cfg.dataset.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator + + def get_valid_iterator( + self, + subset, + disable_iterator_cache=False, + ): + """Return an EpochBatchIterator over given validation subset for a given epoch.""" + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.dataset(subset), + max_tokens=self.cfg.dataset.max_tokens_valid, + max_sentences=self.cfg.dataset.batch_size_valid, + max_positions=utils.resolve_max_positions( + self.task.max_positions(), + self.model.max_positions(), + ), + ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, + num_shards=self.data_parallel_world_size, + shard_id=self.data_parallel_rank, + num_workers=self.cfg.dataset.num_workers, + data_buffer_size=self.cfg.dataset.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator + + def begin_epoch(self, epoch): + """Called at the beginning of each epoch.""" + logger.info("begin training epoch {}".format(epoch)) + + if self.quantizer is not None: + self.quantizer.begin_epoch(epoch) + + # task specific setup per epoch + self.task.begin_epoch(epoch, self.get_model()) + + if self.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous("begin_epoch") # wait for all workers + xm.mark_step() + + def begin_valid_epoch(self, epoch): + """Called at the beginning of each validation epoch.""" + + # task specific setup per validation epoch + self.task.begin_valid_epoch(epoch, self.get_model()) + + def reset_dummy_batch(self, batch): + self._dummy_batch = batch + + @metrics.aggregate("train") + def train_step(self, samples, raise_oom=False): + """Do forward, backward and parameter update.""" + self._set_seed() + self.model.train() + self.criterion.train() + self.zero_grad() + + metrics.log_start_time("train_wall", priority=800, round=0) + + # forward and backward pass + logging_outputs, sample_size, ooms = [], 0, 0 + for i, sample in enumerate(samples): + sample = self._prepare_sample(sample) + if sample is None: + # when sample is None, run forward/backward on a dummy batch + # and ignore the resulting gradients + sample = self._prepare_sample(self._dummy_batch) + is_dummy_batch = True + else: + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample + is_dummy_batch = False + + def maybe_no_sync(): + """ + Whenever *samples* contains more than one mini-batch, we + want to accumulate gradients locally and only call + all-reduce in the last backwards pass. + """ + if ( + self.data_parallel_world_size > 1 + and hasattr(self.model, "no_sync") + and i < len(samples) - 1 + ): + return self.model.no_sync() + else: + return contextlib.ExitStack() # dummy contextmanager + + try: + with maybe_no_sync(): + # forward and backward + loss, sample_size_i, logging_output = self.task.train_step( + sample=sample, + model=self.model, + criterion=self.criterion, + optimizer=self.optimizer, + update_num=self.get_num_updates(), + ignore_grad=is_dummy_batch, + ) + del loss + + logging_outputs.append(logging_output) + sample_size += sample_size_i + + # emptying the CUDA cache after the first step can + # reduce the chance of OOM + if self.cuda and self.get_num_updates() == 0: + torch.cuda.empty_cache() + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + if raise_oom: + raise e + logger.warning( + "attempting to recover from OOM in forward/backward pass" + ) + ooms += 1 + self.zero_grad() + if self.cuda: + torch.cuda.empty_cache() + if self.cfg.distributed_training.distributed_world_size == 1: + return None + else: + raise e + + if self.tpu and i < len(samples) - 1: + # tpu-comment: every XLA operation before marking step is + # appended to the IR graph, and processing too many batches + # before marking step can lead to OOM errors. + # To handle gradient accumulation use case, we explicitly + # mark step here for every forward pass without a backward pass + import torch_xla.core.xla_model as xm + + xm.mark_step() + + if is_dummy_batch: + if torch.is_tensor(sample_size): + sample_size.zero_() + else: + sample_size *= 0.0 + + if torch.is_tensor(sample_size): + sample_size = sample_size.float() + else: + sample_size = float(sample_size) + + # gather logging outputs from all replicas + if self._sync_stats(): + train_time = self._local_cumulative_training_time() + logging_outputs, ( + sample_size, + ooms, + total_train_time, + ) = self._aggregate_logging_outputs( + logging_outputs, + sample_size, + ooms, + train_time, + ignore=is_dummy_batch, + ) + self._cumulative_training_time = ( + total_train_time / self.data_parallel_world_size + ) + + if hasattr(self.model, "all_reduce"): + self.model.all_reduce() + + overflow = False + try: + if self.tpu and self.data_parallel_world_size > 1: + import torch_xla.core.xla_model as xm + + gradients = xm._fetch_gradients(self.optimizer.optimizer) + xm.all_reduce( + "sum", gradients, scale=1.0 / self.data_parallel_world_size + ) + + with torch.autograd.profiler.record_function("multiply-grads"): + # multiply gradients by (# GPUs / sample_size) since DDP + # already normalizes by the number of GPUs. Thus we get + # (sum_of_gradients / sample_size). + if not self.cfg.optimization.use_bmuf: + self.optimizer.multiply_grads( + self.data_parallel_world_size / sample_size + ) + elif sample_size > 0: # BMUF needs to check sample size + num = self.data_parallel_world_size if self._sync_stats() else 1 + self.optimizer.multiply_grads(num / sample_size) + + with torch.autograd.profiler.record_function("clip-grads"): + # clip grads + grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) + + # check that grad norms are consistent across workers + if ( + not self.cfg.optimization.use_bmuf + and self.cfg.distributed_training.distributed_wrapper != "SlowMo" + and not self.tpu + ): + self._check_grad_norms(grad_norm) + + with torch.autograd.profiler.record_function("optimizer"): + # take an optimization step + self.optimizer.step() + + except FloatingPointError: + # re-run the forward and backward pass with hooks attached to print + # out where it fails + with NanDetector(self.get_model()): + self.task.train_step( + sample, + self.model, + self.criterion, + self.optimizer, + self.get_num_updates(), + ignore_grad=False, + ) + raise + except OverflowError as e: + overflow = True + logger.info("NOTE: overflow detected, " + str(e)) + grad_norm = torch.tensor(0.0).cuda() + self.zero_grad() + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + logger.error("OOM during optimization, irrecoverable") + raise e + + # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step + if hasattr(self.model, "perform_additional_optimizer_actions"): + if hasattr(self.optimizer, "fp32_params"): + self.model.perform_additional_optimizer_actions( + self.optimizer.optimizer, self.optimizer.fp32_params + ) + else: + self.model.perform_additional_optimizer_actions( + self.optimizer.optimizer + ) + + if ( + not overflow + or self.cfg.distributed_training.distributed_wrapper == "SlowMo" + ): + self.set_num_updates(self.get_num_updates() + 1) + + if self.tpu: + # mark step on TPUs + import torch_xla.core.xla_model as xm + + xm.mark_step() + + # only log stats every log_interval steps + # this causes wps to be misreported when log_interval > 1 + logging_output = {} + if self.get_num_updates() % self.cfg.common.log_interval == 0: + # log memory usage + mem_info = xm.get_memory_info(self.device) + gb_free = mem_info["kb_free"] / 1024 / 1024 + gb_total = mem_info["kb_total"] / 1024 / 1024 + metrics.log_scalar( + "gb_free", + gb_free, + priority=1500, + round=1, + weight=0, + ) + metrics.log_scalar( + "gb_total", + gb_total, + priority=1600, + round=1, + weight=0, + ) + + logging_output = self._reduce_and_log_stats( + logging_outputs, + sample_size, + grad_norm, + ) + + # log whenever there's an XLA compilation, since these + # slow down training and may indicate opportunities for + # optimization + self._check_xla_compilation() + else: + # log stats + logging_output = self._reduce_and_log_stats( + logging_outputs, + sample_size, + grad_norm, + ) + + # clear CUDA cache to reduce memory fragmentation + if ( + self.cuda + and self.cfg.common.empty_cache_freq > 0 + and ( + (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1) + % self.cfg.common.empty_cache_freq + ) + == 0 + ): + torch.cuda.empty_cache() + + if self.cfg.common.fp16: + metrics.log_scalar( + "loss_scale", + self.optimizer.scaler.loss_scale, + priority=700, + round=4, + weight=0, + ) + + metrics.log_stop_time("train_wall") + return logging_output + + @metrics.aggregate("valid") + def valid_step(self, sample, raise_oom=False): + """Do forward pass in evaluation mode.""" + if self.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous("valid_step") # wait for all workers + xm.mark_step() + + with torch.no_grad(): + self.model.eval() + self.criterion.eval() + + sample = self._prepare_sample(sample) + if sample is None: + sample = self._prepare_sample(self._dummy_batch) + is_dummy_batch = True + else: + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample + is_dummy_batch = False + + try: + _loss, sample_size, logging_output = self.task.valid_step( + sample, self.model, self.criterion + ) + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + if not raise_oom: + logger.warning( + "ran out of memory in validation step, retrying batch" + ) + for p in self.model.parameters(): + if p.grad is not None: + p.grad = None # free some memory + if self.cuda: + torch.cuda.empty_cache() + return self.valid_step(sample, raise_oom=True) + raise e + + logging_outputs = [logging_output] + if is_dummy_batch: + if torch.is_tensor(sample_size): + sample_size.zero_() + else: + sample_size *= 0.0 + + # gather logging outputs from all replicas + if self.data_parallel_world_size > 1: + logging_outputs, (sample_size,) = self._aggregate_logging_outputs( + logging_outputs, + sample_size, + ignore=is_dummy_batch, + ) + + # log validation stats + logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) + + return logging_output + + def zero_grad(self): + self.optimizer.zero_grad() + + def lr_step(self, epoch, val_loss=None): + """Adjust the learning rate at the end of the epoch.""" + self.lr_scheduler.step(epoch, val_loss) + # prefer updating the LR based on the number of steps + return self.lr_step_update() + + def lr_step_update(self): + """Update the learning rate after each update.""" + new_lr = self.lr_scheduler.step_update(self.get_num_updates()) + metrics.log_scalar("lr", new_lr, weight=0, priority=300) + return new_lr + + def get_lr(self): + """Get the current learning rate.""" + return self.optimizer.get_lr() + + def get_model(self): + """Get the (non-wrapped) model instance.""" + return self._model + + def get_criterion(self): + """Get the (non-wrapped) criterion instance.""" + return self._criterion + + def get_meter(self, name): + """[deprecated] Get a specific meter by name.""" + from fairseq import meters + + if "get_meter" not in self._warn_once: + self._warn_once.add("get_meter") + utils.deprecation_warning( + "Trainer.get_meter is deprecated. Please use fairseq.metrics instead." + ) + + train_meters = metrics.get_meters("train") + if train_meters is None: + train_meters = {} + + if name == "train_loss" and "loss" in train_meters: + return train_meters["loss"] + elif name == "train_nll_loss": + # support for legacy train.py, which assumed this meter is + # always initialized + m = train_meters.get("nll_loss", None) + return m or meters.AverageMeter() + elif name == "wall": + # support for legacy train.py, which assumed this meter is + # always initialized + m = metrics.get_meter("default", "wall") + return m or meters.TimeMeter() + elif name == "wps": + m = metrics.get_meter("train", "wps") + return m or meters.TimeMeter() + elif name in {"valid_loss", "valid_nll_loss"}: + # support for legacy train.py, which assumed these meters + # are always initialized + k = name[len("valid_") :] + m = metrics.get_meter("valid", k) + return m or meters.AverageMeter() + elif name == "oom": + return meters.AverageMeter() + elif name in train_meters: + return train_meters[name] + return None + + def get_num_updates(self): + """Get the number of parameters updates.""" + return self._num_updates + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self._num_updates = num_updates + self.lr_step_update() + if self.quantizer: + self.quantizer.step_update(self._num_updates) + metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) + + def clip_grad_norm(self, clip_norm): + return self.optimizer.clip_grad_norm(clip_norm, aggregate_norm_fn=None) + + def cumulative_training_time(self): + if self._cumulative_training_time is None: + # single GPU + return self._local_cumulative_training_time() + else: + return self._cumulative_training_time + + def _local_cumulative_training_time(self): + """Aggregate training time in seconds.""" + return time.time() - self._start_time + self._previous_training_time + + def _prepare_sample(self, sample): + if sample == "DUMMY": + raise Exception( + "Trying to use an uninitialized 'dummy' batch. This usually indicates " + "that the total number of batches is smaller than the number of " + "participating GPUs. Try reducing the batch size or using fewer GPUs." + ) + + if sample is None or len(sample) == 0: + return None + + if self.cuda: + if self.pipeline_model_parallel: + if "target" in sample: + sample["target"] = utils.move_to_cuda( + sample["target"], device=self.last_device + ) + else: + sample = utils.move_to_cuda(sample) + + def apply_half(t): + if t.dtype is torch.float32: + return t.half() + return t + + def apply_bfloat16(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.bfloat16) + return t + + if self.cfg.common.fp16: + sample = utils.apply_to_sample(apply_half, sample) + + if self.cfg.common.bf16: + sample = utils.apply_to_sample(apply_bfloat16, sample) + + return sample + + def _set_seed(self): + # Set seed based on args.seed and the update number so that we get + # reproducible results when resuming from checkpoints + seed = self.cfg.common.seed + self.get_num_updates() + utils.set_torch_seed(seed) + + def _sync_stats(self): + # Return True if it's using multiple GPUs and DDP or multiple GPUs with + # BMUF and it's a bmuf sync with warmup iterations completed before. + if self.data_parallel_world_size == 1: + return False + elif self.cfg.optimization.use_bmuf: + return ( + self.get_num_updates() + 1 + ) % self.cfg.bmuf.global_sync_iter == 0 and ( + self.get_num_updates() + 1 + ) > self.cfg.bmuf.warmup_iterations + else: + return True + + def _log_oom(self, exc): + msg = "OOM: Ran out of memory with exception: {}".format(exc) + logger.warning(msg) + if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"): + for device_idx in range(torch.cuda.device_count()): + logger.warning(torch.cuda.memory_summary(device=device_idx)) + sys.stderr.flush() + + def _aggregate_logging_outputs( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()): + return self._fast_stat_sync_sum( + logging_outputs, *extra_stats_to_sum, ignore=ignore + ) + else: + return self._all_gather_list_sync( + logging_outputs, *extra_stats_to_sum, ignore=ignore + ) + + def _all_gather_list_sync( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + """ + Sync logging outputs across workers. all_gather_list_sync is + suitable when logging outputs are complex types. + """ + if self.tpu: + raise NotImplementedError + if ignore: + logging_outputs = [] + results = list( + zip( + *distributed_utils.all_gather_list( + [logging_outputs] + list(extra_stats_to_sum), + max_size=getattr(self.cfg.common, "all_gather_list_size", 16384), + group=self.data_parallel_process_group, + ) + ) + ) + logging_outputs, extra_stats_to_sum = results[0], results[1:] + logging_outputs = list(chain.from_iterable(logging_outputs)) + extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum] + return logging_outputs, extra_stats_to_sum + + def _fast_stat_sync_sum( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + """ + Sync logging outputs across workers. fast_stat_sync_sum is + faster than all_gather_list_sync, but is only suitable when + logging outputs are scalars and can be summed. Note that + *logging_outputs* cannot contain any nested dicts/lists. + """ + data = {} + for i, stat in enumerate(extra_stats_to_sum): + data["extra_stats_" + str(i)] = stat + if len(logging_outputs) > 0: + log_keys = list(logging_outputs[0].keys()) + for k in log_keys: + if not ignore: + v = sum(log[k] for log in logging_outputs if k in log) + else: + v = logging_outputs[0][k] + v = torch.zeros_like(v) if torch.is_tensor(v) else 0 + data["logging_outputs_" + k] = v + else: + log_keys = None + + data = distributed_utils.all_reduce_dict( + data, device=self.device, group=self.data_parallel_process_group + ) + + extra_stats_to_sum = [ + data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum)) + ] + if log_keys is not None: + logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}] + else: + logging_outputs = [] + return logging_outputs, extra_stats_to_sum + + def _check_grad_norms(self, grad_norm): + """Check that grad norms are consistent across workers.""" + if self._grad_norm_buf is not None: + self._grad_norm_buf.zero_() + self._grad_norm_buf[self.data_parallel_rank] = grad_norm + distributed_utils.all_reduce( + self._grad_norm_buf, group=self.data_parallel_process_group + ) + + def is_consistent(tensor): + max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) + return ( + not torch.isfinite(tensor).any() + or (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + ) + + if not is_consistent(self._grad_norm_buf): + pretty_detail = "\n".join( + "rank {:3d} = {:.8f}".format(r, n) + for r, n in enumerate(self._grad_norm_buf.tolist()) + ) + error_detail = "grad_norm across the workers:\n{}\n".format( + pretty_detail + ) + raise RuntimeError( + "Fatal error: gradients are inconsistent between workers. " + "Try --ddp-backend=no_c10d. " + "Or are you mixing up different generation of GPUs in training?" + + "\n" + + "-" * 80 + + "\n{}\n".format(error_detail) + + "-" * 80 + ) + + def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): + if grad_norm is not None: + metrics.log_speed("ups", 1.0, priority=100, round=2) + metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) + if self.cfg.optimization.clip_norm > 0: + metrics.log_scalar( + "clip", + torch.where( + grad_norm > self.cfg.optimization.clip_norm, + grad_norm.new_tensor(100), + grad_norm.new_tensor(0), + ), + priority=500, + round=1, + ) + + with metrics.aggregate() as agg: + if logging_outputs is not None: + self.task.reduce_metrics(logging_outputs, self.get_criterion()) + del logging_outputs + + # extra warning for criterions that don't properly log a loss value + if "loss" not in agg: + if "loss" not in self._warn_once: + self._warn_once.add("loss") + logger.warning( + "Criterion.reduce_metrics did not log a 'loss' value, " + "which may break some functionality" + ) + metrics.log_scalar("loss", -1) + + # support legacy interface + if self.tpu: + logging_output = {} + else: + logging_output = agg.get_smoothed_values() + logging_output["sample_size"] = sample_size + for key_to_delete in ["ppl", "wps", "wpb", "bsz"]: + if key_to_delete in logging_output: + del logging_output[key_to_delete] + return logging_output + + def _check_xla_compilation(self): + import torch_xla.debug.metrics as met + + compile_stats = met.metric_data("CompileTime") + if compile_stats is None: + return + num_xla_compiles = compile_stats[0] + if num_xla_compiles > self._num_xla_compiles: + logger.warning( + "XLA compilation detected on device #{}; too many of these can lead " + "to slow training, but we expect a few in the beginning".format( + self.cfg.distributed_training.distributed_rank + ) + ) + self._num_xla_compiles = num_xla_compiles + + +def _catalog_shared_params(module, memo=None, prefix=""): + if memo is None: + first_call = True + memo = {} + else: + first_call = False + for name, param in module._parameters.items(): + param_prefix = prefix + ("." if prefix else "") + name + if param not in memo: + memo[param] = [] + memo[param].append(param_prefix) + for name, m in module._modules.items(): + if m is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + _catalog_shared_params(m, memo, submodule_prefix) + if first_call: + return [x for x in memo.values() if len(x) > 1] + + +def _get_module_by_path(module, path): + path = path.split(".") + for name in path: + module = getattr(module, name) + return module + + +def _set_module_by_path(module, path, value): + path = path.split(".") + for name in path[:-1]: + module = getattr(module, name) + setattr(module, path[-1], value) diff --git a/fairseq/utils.py b/fairseq/utils.py new file mode 100644 index 0000000..87c124b --- /dev/null +++ b/fairseq/utils.py @@ -0,0 +1,720 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import contextlib +import copy +import importlib +import logging +import os +import sys +import tempfile +import warnings +from itertools import accumulate +from typing import Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from fairseq.data import iterators +from fairseq.file_io import PathManager +from fairseq.logging.meters import safe_round +from fairseq.modules import gelu, gelu_accurate +from fairseq.modules.multihead_attention import MultiheadAttention +from torch import Tensor + + +try: + from amp_C import multi_tensor_l2norm + + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +logger = logging.getLogger(__name__) + + +MANIFOLD_PATH_SEP = "|" + + +class FileContentsAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(FileContentsAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + if PathManager.isfile(values): + with PathManager.open(values) as f: + argument = f.read().strip() + else: + argument = values + setattr(namespace, self.dest, argument) + + +def split_paths(paths: str) -> List[str]: + return ( + paths.split(os.pathsep) + if "://" not in paths + else paths.split(MANIFOLD_PATH_SEP) + ) + + +def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): + from fairseq import checkpoint_utils + + deprecation_warning( + "utils.load_ensemble_for_inference is deprecated. " + "Please use checkpoint_utils.load_model_ensemble instead." + ) + return checkpoint_utils.load_model_ensemble( + filenames, arg_overrides=model_arg_overrides, task=task + ) + + +def apply_to_sample(f, sample): + if hasattr(sample, "__len__") and len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + elif isinstance(x, tuple): + return tuple(_apply(x) for x in x) + elif isinstance(x, set): + return {_apply(x) for x in x} + else: + return x + + return _apply(sample) + + +def move_to_cuda(sample, device=None): + device = device or torch.cuda.current_device() + + def _move_to_cuda(tensor): + # non_blocking is ignored if tensor is not pinned, so we can always set + # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620) + return tensor.cuda(device=device, non_blocking=True) + + return apply_to_sample(_move_to_cuda, sample) + + +def move_to_cpu(sample): + def _move_to_cpu(tensor): + # PyTorch has poor support for half tensors (float16) on CPU. + # Move any such tensors to float32. + if tensor.dtype in {torch.bfloat16, torch.float16}: + tensor = tensor.to(dtype=torch.float32) + return tensor.cpu() + + return apply_to_sample(_move_to_cpu, sample) + + +def get_incremental_state( + module: MultiheadAttention, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, +) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + return module.get_incremental_state(incremental_state, key) + + +def set_incremental_state( + module: MultiheadAttention, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], +) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + result = module.set_incremental_state(incremental_state, key, value) + if result is not None: + incremental_state = result + return incremental_state + + +def load_align_dict(replace_unk): + if replace_unk is None: + align_dict = None + elif isinstance(replace_unk, str) and len(replace_unk) > 0: + # Load alignment dictionary for unknown word replacement if it was passed as an argument. + align_dict = {} + with open(replace_unk, "r") as f: + for line in f: + cols = line.split() + align_dict[cols[0]] = cols[1] + else: + # No alignment dictionary provided but we still want to perform unknown word replacement by copying the + # original source word. + align_dict = {} + return align_dict + + +def print_embed_overlap(embed_dict, vocab_dict): + embed_keys = set(embed_dict.keys()) + vocab_keys = set(vocab_dict.symbols) + overlap = len(embed_keys & vocab_keys) + logger.info("found {}/{} types in embedding file".format(overlap, len(vocab_dict))) + + +def parse_embedding(embed_path): + """Parse embedding text file into a dictionary of word and embedding tensors. + + The first line can have vocabulary size and dimension. The following lines + should contain word and embedding separated by spaces. + + Example: + 2 5 + the -0.0230 -0.0264 0.0287 0.0171 0.1403 + at -0.0395 -0.1286 0.0275 0.0254 -0.0932 + """ + embed_dict = {} + with open(embed_path) as f_embed: + next(f_embed) # skip header + for line in f_embed: + pieces = line.rstrip().split(" ") + embed_dict[pieces[0]] = torch.Tensor( + [float(weight) for weight in pieces[1:]] + ) + return embed_dict + + +def load_embedding(embed_dict, vocab, embedding): + for idx in range(len(vocab)): + token = vocab[idx] + if token in embed_dict: + embedding.weight.data[idx] = embed_dict[token] + return embedding + + +def replace_unk(hypo_str, src_str, alignment, align_dict, unk): + from fairseq import tokenizer + + # Tokens are strings here + hypo_tokens = tokenizer.tokenize_line(hypo_str) + # TODO: Very rare cases where the replacement is '' should be handled gracefully + src_tokens = tokenizer.tokenize_line(src_str) + [""] + for i, ht in enumerate(hypo_tokens): + if ht == unk: + src_token = src_tokens[alignment[i]] + # Either take the corresponding value in the aligned dictionary or just copy the original value. + hypo_tokens[i] = align_dict.get(src_token, src_token) + return " ".join(hypo_tokens) + + +def post_process_prediction( + hypo_tokens, + src_str, + alignment, + align_dict, + tgt_dict, + remove_bpe=None, + extra_symbols_to_ignore=None, +): + hypo_str = tgt_dict.string( + hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore + ) + if align_dict is not None: + hypo_str = replace_unk( + hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string() + ) + if align_dict is not None or remove_bpe is not None: + # Convert back to tokens for evaluating with unk replacement or without BPE + # Note that the dictionary can be modified inside the method. + hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True) + return hypo_tokens, hypo_str, alignment + + +def make_positions(tensor, padding_idx: int, onnx_trace: bool = False): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx + + +def strip_pad(tensor, pad): + return tensor[tensor.ne(pad)] + + +def buffered_arange(max): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor() + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +def convert_padding_direction( + src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False +): + assert right_to_left ^ left_to_right + pad_mask = src_tokens.eq(padding_idx) + if not pad_mask.any(): + # no padding, return early + return src_tokens + if left_to_right and not pad_mask[:, 0].any(): + # already right padded + return src_tokens + if right_to_left and not pad_mask[:, -1].any(): + # already left padded + return src_tokens + max_len = src_tokens.size(1) + buffered = torch.empty(0).long() + if max_len > 0: + torch.arange(max_len, out=buffered) + range = buffered.type_as(src_tokens).expand_as(src_tokens) + num_pads = pad_mask.long().sum(dim=1, keepdim=True) + if right_to_left: + index = torch.remainder(range - num_pads, max_len) + else: + index = torch.remainder(range + num_pads, max_len) + return src_tokens.gather(1, index) + + +def item(tensor): + if hasattr(tensor, "item"): + return tensor.item() + if hasattr(tensor, "__getitem__"): + return tensor[0] + return tensor + + +def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: + per_device_grads = {} + norms = [] + for grad in grads: + device = grad.device + cur_device_grads = per_device_grads.get(device) + if cur_device_grads is None: + cur_device_grads = [] + per_device_grads[device] = cur_device_grads + cur_device_grads.append(grad) + for device in per_device_grads.keys(): + cur_device_grads = per_device_grads[device] + if device.type == "cuda": + # TODO(msb) return has_inf + has_inf = torch.zeros((1, 1), dtype=torch.int, device=device) + with torch.cuda.device(device): + norm = multi_tensor_l2norm( + chunk_size, has_inf, [cur_device_grads], False + ) + norms.append(norm[0].to(torch.cuda.current_device())) + else: + norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads] + total_norm = torch.norm(torch.stack(norms)) + return total_norm + + +@torch.no_grad() +def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: + if isinstance(params, torch.Tensor): + params = [params] + params = list(params) + grads = [p.grad.detach() for p in filter(lambda p: p.grad is not None, params)] + if len(grads) == 0: + if len(params) > 0: + return params[0].new_tensor(0.0) + else: + return torch.tensor(0.0) + + if len(grads) == 1: + total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) + else: + if multi_tensor_l2norm_available: + total_norm = multi_tensor_total_norm(grads) + else: + if torch.cuda.is_available(): + warnings.warn( + "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " + "you may get better performance by installing NVIDIA's apex library" + ) + device = torch.cuda.current_device() + elif grads[0].device.type == "xla": + device = grads[0].device + else: + device = torch.device("cpu") + total_norm = torch.norm( + torch.stack( + [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads] + ) + ) + + if aggregate_norm_fn is not None: + total_norm = aggregate_norm_fn(total_norm) + + if max_norm > 0: + max_norm = float(max_norm) + clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) + for g in grads: + g.mul_(clip_coef) + return total_norm + + +def fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + +def _match_types(arg1, arg2): + """Convert the numerical argument to the same type as the other argument""" + + def upgrade(arg_number, arg_structure): + if isinstance(arg_structure, tuple): + return tuple([arg_number] * len(arg_structure)) + elif isinstance(arg_structure, dict): + arg = copy.deepcopy(arg_structure) + for k in arg: + arg[k] = upgrade(arg_number, arg_structure[k]) + return arg + else: + return arg_number + + if isinstance(arg1, float) or isinstance(arg1, int): + return upgrade(arg1, arg2), arg2 + elif isinstance(arg2, float) or isinstance(arg2, int): + return arg1, upgrade(arg2, arg1) + + return arg1, arg2 + + +def resolve_max_positions(*args): + """Resolve max position constraints from multiple sources.""" + + def map_value_update(d1, d2): + updated_value = copy.deepcopy(d1) + for key in d2: + if key not in updated_value: + updated_value[key] = d2[key] + else: + updated_value[key] = min(d1[key], d2[key]) + return updated_value + + def nullsafe_min(l): + minim = None + for item in l: + if minim is None: + minim = item + elif item is not None and item < minim: + minim = item + return minim + + max_positions = None + for arg in args: + if max_positions is None: + max_positions = arg + elif arg is not None: + max_positions, arg = _match_types(max_positions, arg) + if isinstance(arg, float) or isinstance(arg, int): + max_positions = min(max_positions, arg) + elif isinstance(arg, dict): + max_positions = map_value_update(max_positions, arg) + else: + max_positions = tuple(map(nullsafe_min, zip(max_positions, arg))) + + return max_positions + + +def import_user_module(args): + module_path = getattr(args, "user_dir", None) + if module_path is not None: + module_path = os.path.abspath(args.user_dir) + if not os.path.exists(module_path): + fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) + if os.path.exists(fairseq_rel_path): + module_path = fairseq_rel_path + else: + fairseq_rel_path = os.path.join( + os.path.dirname(__file__), "..", args.user_dir + ) + if os.path.exists(fairseq_rel_path): + module_path = fairseq_rel_path + else: + raise FileNotFoundError(module_path) + + # ensure that user modules are only imported once + import_user_module.memo = getattr(import_user_module, "memo", set()) + if module_path not in import_user_module.memo: + import_user_module.memo.add(module_path) + + module_parent, module_name = os.path.split(module_path) + if module_name not in sys.modules: + sys.path.insert(0, module_parent) + importlib.import_module(module_name) + else: + raise ImportError( + "Failed to import --user-dir={} because the corresponding module name " + "({}) is not globally unique. Please rename the directory to " + "something unique and try again.".format(module_path, module_name) + ) + + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def log_softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.log_softmax(x.float(), dim=dim) + else: + return F.log_softmax(x, dim=dim, dtype=torch.float32) + + +def get_perplexity(loss, round=2, base=2): + if loss is None: + return 0.0 + try: + return safe_round(base ** loss, round) + except OverflowError: + return float("inf") + + +def deprecation_warning(message, stacklevel=3): + # don't use DeprecationWarning, since it's ignored by default + warnings.warn(message, stacklevel=stacklevel) + + +def get_activation_fn(activation: str) -> Callable: + """ Returns the activation function corresponding to `activation` """ + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + deprecation_warning( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def get_available_activation_fns() -> List: + return [ + "relu", + "gelu", + "gelu_fast", # deprecated + "gelu_accurate", + "tanh", + "linear", + ] + + +@contextlib.contextmanager +def model_eval(model): + is_training = model.training + model.eval() + yield + model.train(is_training) + + +def has_parameters(module): + try: + next(module.parameters()) + return True + except StopIteration: + return False + + +def get_rng_state(): + state = {"torch_rng_state": torch.get_rng_state()} + if xm is not None: + state["xla_rng_state"] = xm.get_rng_state() + if torch.cuda.is_available(): + state["cuda_rng_state"] = torch.cuda.get_rng_state() + return state + + +def set_rng_state(state): + torch.set_rng_state(state["torch_rng_state"]) + if xm is not None: + xm.set_rng_state(state["xla_rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state["cuda_rng_state"]) + + +class set_torch_seed(object): + def __init__(self, seed): + assert isinstance(seed, int) + self.rng_state = get_rng_state() + + torch.manual_seed(seed) + if xm is not None: + xm.set_rng_state(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + def __enter__(self): + return self + + def __exit__(self, *exc): + set_rng_state(self.rng_state) + + +def parse_alignment(line): + """ + Parses a single line from the alingment file. + + Args: + line (str): String containing the alignment of the format: + - - .. + -. All indices are 0 indexed. + + Returns: + torch.IntTensor: packed alignments of shape (2 * m). + """ + alignments = line.strip().split() + parsed_alignment = torch.IntTensor(2 * len(alignments)) + for idx, alignment in enumerate(alignments): + src_idx, tgt_idx = alignment.split("-") + parsed_alignment[2 * idx] = int(src_idx) + parsed_alignment[2 * idx + 1] = int(tgt_idx) + return parsed_alignment + + +def get_token_to_word_mapping(tokens, exclude_list): + n = len(tokens) + word_start = [int(token not in exclude_list) for token in tokens] + word_idx = list(accumulate(word_start)) + token_to_word = {i: word_idx[i] for i in range(n)} + return token_to_word + + +def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ( + ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + src_invalid = ( + ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) + tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) + alignment = [] + if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): + attn_valid = attn[tgt_valid] + attn_valid[:, src_invalid] = float("-inf") + _, src_indices = attn_valid.max(dim=1) + for tgt_idx, src_idx in zip(tgt_valid, src_indices): + alignment.append( + ( + src_token_to_word[src_idx.item()] - 1, + tgt_token_to_word[tgt_idx.item()] - 1, + ) + ) + return alignment + + +def new_arange(x, *size): + """ + Return a Tensor of `size` filled with a range function on the device of x. + If size is empty, using the size of the variable x. + """ + if len(size) == 0: + size = x.size() + return torch.arange(size[-1], device=x.device).expand(*size).contiguous() + + +def get_tpu_device(args): + return xm.xla_device() + + +def tpu_data_loader(itr): + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl + + xm.rendezvous("tpu_data_loader") # wait for all workers + xm.mark_step() + device = xm.xla_device() + return iterators.CountingIterator( + pl.ParallelLoader(itr, [device]).per_device_loader(device), + start=getattr(itr, "n", 0), + total=len(itr), + ) + + +class CudaEnvironment(object): + def __init__(self): + cur_device = torch.cuda.current_device() + prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device)) + self.name = prop.name + self.major = prop.major + self.minor = prop.minor + self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024 + + @staticmethod + def pretty_print_cuda_env_list(cuda_env_list): + """ + Given a list of CudaEnviorments, pretty print them + """ + num_workers = len(cuda_env_list) + center = "CUDA enviroments for all {} workers".format(num_workers) + banner_len = 40 - len(center) // 2 + first_line = "*" * banner_len + center + "*" * banner_len + logger.info(first_line) + for r, env in enumerate(cuda_env_list): + logger.info( + "rank {:3d}: ".format(r) + + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor) + + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB) + + "name = {:40s}".format(env.name) + ) + logger.info(first_line) + + +def csv_str_list(x): + return x.split(",") + + +def eval_str_list(x, type=float): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + try: + return list(map(type, x)) + except TypeError: + return [type(x)] + + +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + +def eval_bool(x, default=False): + if x is None: + return default + try: + return bool(eval(x)) + except TypeError: + return default diff --git a/fairseq/version.txt b/fairseq/version.txt new file mode 100644 index 0000000..41432f0 --- /dev/null +++ b/fairseq/version.txt @@ -0,0 +1 @@ +1.0.0a0 diff --git a/fairseq_cli/__init__.py b/fairseq_cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py new file mode 100644 index 0000000..1197d69 --- /dev/null +++ b/fairseq_cli/eval_lm.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Evaluate the perplexity of a trained language model. +""" + +import logging +import math +import os +from argparse import Namespace + +import torch +from fairseq import checkpoint_utils, distributed_utils, options, utils +from fairseq.data import LMContextWindowDataset +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.sequence_scorer import SequenceScorer +from omegaconf import DictConfig + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) +logger = logging.getLogger("fairseq_cli.eval_lm") + + +class WordStat(object): + def __init__(self, word, is_bpe): + self.word = word + self.is_bpe = is_bpe + self.log_prob = 0 + self.next_word_prob = 0 + self.count = 0 + self.missing_next_words = 0 + + def add(self, log_prob, next_word_prob): + """increments counters for the sum of log probs of current word and next + word (given context ending at current word). Since the next word might be at the end of the example, + or it might be not counted because it is not an ending subword unit, + also keeps track of how many of those we have seen""" + if next_word_prob is not None: + self.next_word_prob += next_word_prob + else: + self.missing_next_words += 1 + self.log_prob += log_prob + self.count += 1 + + def __str__(self): + return "{}\t{}\t{}\t{}\t{}\t{}".format( + self.word, + self.count, + self.log_prob, + self.is_bpe, + self.next_word_prob, + self.count - self.missing_next_words, + ) + + +def main(cfg: DictConfig, override_args=None, **unused_kwargs): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) + + if override_args is not None: + overrides = vars(override_args) + overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) + else: + overrides = None + + logger.info(cfg) + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + + # reduce tokens per sample by the required context window size + cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=overrides, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + ) + + # Load dataset splits + gen_subset = cfg.dataset.gen_subset + task.load_dataset(gen_subset) + dataset = task.dataset(gen_subset) + if cfg.eval_lm.context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=cfg.task.tokens_per_sample, + context_window=cfg.eval_lm.context_window, + pad_idx=task.source_dictionary.pad(), + ) + logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset))) + + # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) + for model in models: + if use_fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + assert len(models) > 0 + + logger.info( + "num. model params: {}".format(sum(p.numel() for p in models[0].parameters())) + ) + + itr = task.get_batch_iterator( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens or 36000, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + *[model.max_positions() for model in models] + ), + ignore_invalid_inputs=True, + num_shards=max( + cfg.dataset.num_shards, + cfg.distributed_training.distributed_world_size, + ), + shard_id=max( + cfg.dataset.shard_id, + cfg.distributed_training.distributed_rank, + ), + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + gen_timer = StopwatchMeter() + scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch) + + score_sum = 0.0 + count = 0 + + if cfg.common_eval.post_process is not None: + if cfg.common_eval.post_process == "sentencepiece": + raise NotImplementedError + else: + bpe_cont = cfg.common_eval.post_process.rstrip() + bpe_toks = { + i + for i in range(len(task.source_dictionary)) + if task.source_dictionary[i].endswith(bpe_cont) + } + bpe_len = len(bpe_cont) + else: + bpe_toks = None + bpe_len = 0 + + word_stats = dict() + + wps_meter = TimeMeter() + + for sample in progress: + if "net_input" not in sample: + continue + + sample = utils.move_to_cuda(sample) if use_cuda else sample + + gen_timer.start() + hypos = scorer.generate(models, sample) + gen_timer.stop(sample["ntokens"]) + + for i, hypos_i in enumerate(hypos): + hypo = hypos_i[0] + sample_id = sample["id"][i] + + tokens = hypo["tokens"] + tgt_len = tokens.numel() + pos_scores = hypo["positional_scores"].float() + + if cfg.task.add_bos_token: + assert hypo["tokens"][0].item() == task.target_dictionary.bos() + tokens = tokens[1:] + pos_scores = pos_scores[1:] + + skipped_toks = 0 + if bpe_toks is not None: + for i in range(tgt_len - 1): + if tokens[i].item() in bpe_toks: + skipped_toks += 1 + pos_scores[i + 1] += pos_scores[i] + pos_scores[i] = 0 + + inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf")) + if inf_scores.any(): + logger.info( + "skipping tokens with inf scores:", + task.target_dictionary.string(tokens[inf_scores.nonzero()]), + ) + pos_scores = pos_scores[(~inf_scores).nonzero()] + score_sum += pos_scores.sum().cpu() + count += pos_scores.numel() - skipped_toks + + if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats: + w = "" + word_prob = [] + is_bpe = False + for i in range(len(tokens)): + w_ind = tokens[i].item() + w += task.source_dictionary[w_ind] + if bpe_toks is not None and w_ind in bpe_toks: + w = w[:-bpe_len] + is_bpe = True + else: + word_prob.append((w, pos_scores[i].item())) + + next_prob = None + ind = i + 1 + while ind < len(tokens): + if pos_scores[ind].item() != 0: + next_prob = pos_scores[ind] + break + ind += 1 + + word_stats.setdefault(w, WordStat(w, is_bpe)).add( + pos_scores[i].item(), next_prob + ) + is_bpe = False + w = "" + if cfg.eval_lm.output_word_probs: + logger.info( + str(int(sample_id)) + + " " + + ( + "\t".join( + "{} [{:2f}]".format(x[0], x[1]) for x in word_prob + ) + ) + ) + + wps_meter.update(sample["ntokens"]) + progress.log({"wps": round(wps_meter.avg)}) + + avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2 + logger.info( + "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format( + gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg + ) + ) + logger.info( + "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( + avg_nll_loss, 2 ** avg_nll_loss + ) + ) + + if cfg.eval_lm.output_word_stats: + for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): + logger.info(ws) + + +def cli_main(): + parser = options.get_eval_lm_parser() + args = options.parse_args_and_arch(parser) + + distributed_utils.call_main(convert_namespace_to_omegaconf(args), main) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py new file mode 100644 index 0000000..021f819 --- /dev/null +++ b/fairseq_cli/generate.py @@ -0,0 +1,393 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Translate pre-processed data with a trained model. +""" + +import ast +import logging +import math +import os +import sys +from argparse import Namespace +from itertools import chain + +import numpy as np +import torch +from fairseq import checkpoint_utils, options, scoring, tasks, utils +from fairseq.data import encoders +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter, TimeMeter +from omegaconf import DictConfig + + +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required for generation!" + assert ( + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" + ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" + + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) + output_path = os.path.join( + cfg.common_eval.results_path, + "generate-{}.txt".format(cfg.dataset.gen_subset), + ) + with open(output_path, "w", buffering=1, encoding="utf-8") as h: + return _main(cfg, h) + else: + return _main(cfg, sys.stdout) + + +def get_symbols_to_strip_from_output(generator): + if hasattr(generator, "symbols_to_strip_from_output"): + return generator.symbols_to_strip_from_output + else: + return {generator.eos} + + +def _main(cfg: DictConfig, output_file): + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=output_file, + ) + logger = logging.getLogger("fairseq_cli.generate") + + utils.import_user_module(cfg.common) + + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 12000 + logger.info(cfg) + + # Fix seed for stochastic decoding + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + # Load dataset splits + task = tasks.setup_task(cfg.task) + task.load_dataset(cfg.dataset.gen_subset) + + # Set dictionaries + try: + src_dict = getattr(task, "source_dictionary", None) + except NotImplementedError: + src_dict = None + tgt_dict = task.target_dictionary + + overrides = ast.literal_eval(cfg.common_eval.model_overrides) + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + utils.split_paths(cfg.common_eval.path), + arg_overrides=overrides, + task=task, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + ) + + if cfg.generation.lm_path is not None: + overrides["data"] = cfg.task.data + + try: + lms, _ = checkpoint_utils.load_model_ensemble( + [cfg.generation.lm_path], arg_overrides=overrides, task=None + ) + except: + logger.warning( + f"Failed to load language model! Please make sure that the language model dict is the same " + f"as target dict and is located in the data dir ({cfg.task.data})" + ) + raise + + assert len(lms) == 1 + else: + lms = [None] + + # Optimize ensemble for generation + for model in chain(models, lms): + if model is None: + continue + if cfg.common.fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + task.max_positions(), *[m.max_positions() for m in models] + ), + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + # Initialize generator + gen_timer = StopwatchMeter() + + extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight} + generator = task.build_generator( + models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + + # Handle tokenization and BPE + tokenizer = encoders.build_tokenizer(cfg.tokenizer) + bpe = encoders.build_bpe(cfg.bpe) + + def decode_fn(x): + if bpe is not None: + x = bpe.decode(x) + if tokenizer is not None: + x = tokenizer.decode(x) + return x + + scorer = scoring.build_scorer(cfg.scoring, tgt_dict) + + num_sentences = 0 + has_target = True + wps_meter = TimeMeter() + for sample in progress: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if "net_input" not in sample: + continue + + prefix_tokens = None + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] + + constraints = None + if "constraints" in sample: + constraints = sample["constraints"] + + gen_timer.start() + hypos = task.inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i, sample_id in enumerate(sample["id"].tolist()): + has_target = sample["target"] is not None + + # Remove padding + if "src_tokens" in sample["net_input"]: + src_tokens = utils.strip_pad( + sample["net_input"]["src_tokens"][i, :], tgt_dict.pad() + ) + else: + src_tokens = None + + target_tokens = None + if has_target: + target_tokens = ( + utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu() + ) + + # Either retrieve the original sentences or regenerate them from tokens. + if align_dict is not None: + src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text( + sample_id + ) + target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text( + sample_id + ) + else: + if src_dict is not None: + src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) + else: + src_str = "" + if has_target: + target_str = tgt_dict.string( + target_tokens, + cfg.common_eval.post_process, + escape_unk=True, + extra_symbols_to_ignore=get_symbols_to_strip_from_output( + generator + ), + ) + + src_str = decode_fn(src_str) + if has_target: + target_str = decode_fn(target_str) + + if not cfg.common_eval.quiet: + if src_dict is not None: + print("S-{}\t{}".format(sample_id, src_str), file=output_file) + if has_target: + print("T-{}\t{}".format(sample_id, target_str), file=output_file) + + # Process top predictions + for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): + hypo_tokens, hypo_str, alignment = utils.post_process_prediction( + hypo_tokens=hypo["tokens"].int().cpu(), + src_str=src_str, + alignment=hypo["alignment"], + align_dict=align_dict, + tgt_dict=tgt_dict, + remove_bpe=cfg.common_eval.post_process, + extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), + ) + detok_hypo_str = decode_fn(hypo_str) + if not cfg.common_eval.quiet: + score = hypo["score"] / math.log(2) # convert to base 2 + # original hypothesis (after tokenization and BPE) + print( + "H-{}\t{}\t{}".format(sample_id, score, hypo_str), + file=output_file, + ) + # detokenized hypothesis + print( + "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str), + file=output_file, + ) + print( + "P-{}\t{}".format( + sample_id, + " ".join( + map( + lambda x: "{:.4f}".format(x), + # convert from base e to base 2 + hypo["positional_scores"] + .div_(math.log(2)) + .tolist(), + ) + ), + ), + file=output_file, + ) + + if cfg.generation.print_alignment: + print( + "A-{}\t{}".format( + sample_id, + " ".join( + [ + "{}-{}".format(src_idx, tgt_idx) + for src_idx, tgt_idx in alignment + ] + ), + ), + file=output_file, + ) + + if cfg.generation.print_step: + print( + "I-{}\t{}".format(sample_id, hypo["steps"]), + file=output_file, + ) + + if cfg.generation.retain_iter_history: + for step, h in enumerate(hypo["history"]): + _, h_str, _ = utils.post_process_prediction( + hypo_tokens=h["tokens"].int().cpu(), + src_str=src_str, + alignment=None, + align_dict=None, + tgt_dict=tgt_dict, + remove_bpe=None, + ) + print( + "E-{}_{}\t{}".format(sample_id, step, h_str), + file=output_file, + ) + + # Score only the top hypothesis + if has_target and j == 0: + if align_dict is not None or cfg.common_eval.post_process is not None: + # Convert back to tokens for evaluation with unk replacement and/or without BPE + target_tokens = tgt_dict.encode_line( + target_str, add_if_not_exist=True + ) + hypo_tokens = tgt_dict.encode_line( + detok_hypo_str, add_if_not_exist=True + ) + if hasattr(scorer, "add_string"): + scorer.add_string(target_str, detok_hypo_str) + else: + scorer.add(target_tokens, hypo_tokens) + + wps_meter.update(num_generated_tokens) + progress.log({"wps": round(wps_meter.avg)}) + num_sentences += ( + sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + ) + + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info( + "Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( + num_sentences, + gen_timer.n, + gen_timer.sum, + num_sentences / gen_timer.sum, + 1.0 / gen_timer.avg, + ) + ) + if has_target: + if cfg.bpe and not cfg.generation.sacrebleu: + if cfg.common_eval.post_process: + logger.warning( + "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" + ) + else: + logger.warning( + "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization" + ) + # use print to be consistent with other main outputs: S-, H-, T-, D- and so on + print( + "Generate {} with beam={}: {}".format( + cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string() + ), + file=output_file, + ) + + return scorer + + +def cli_main(): + parser = options.get_generation_parser() + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py new file mode 100644 index 0000000..530830d --- /dev/null +++ b/fairseq_cli/interactive.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Translate raw text with a trained model. Batches data on-the-fly. +""" + +import ast +import fileinput +import logging +import math +import os +import sys +import time +from argparse import Namespace +from collections import namedtuple + +import numpy as np +import torch +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils +from fairseq.data import encoders +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.token_generation_constraints import pack_constraints, unpack_constraints +from fairseq_cli.generate import get_symbols_to_strip_from_output + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.interactive") + + +Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints") +Translation = namedtuple("Translation", "src_str hypos pos_scores alignments") + + +def buffered_read(input, buffer_size): + buffer = [] + with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h: + for src_str in h: + buffer.append(src_str.strip()) + if len(buffer) >= buffer_size: + yield buffer + buffer = [] + + if len(buffer) > 0: + yield buffer + + +def make_batches(lines, cfg, task, max_positions, encode_fn): + def encode_fn_target(x): + return encode_fn(x) + + if cfg.generation.constraints: + # Strip (tab-delimited) contraints, if present, from input lines, + # store them in batch_constraints + batch_constraints = [list() for _ in lines] + for i, line in enumerate(lines): + if "\t" in line: + lines[i], *batch_constraints[i] = line.split("\t") + + # Convert each List[str] to List[Tensor] + for i, constraint_list in enumerate(batch_constraints): + batch_constraints[i] = [ + task.target_dictionary.encode_line( + encode_fn_target(constraint), + append_eos=False, + add_if_not_exist=False, + ) + for constraint in constraint_list + ] + + tokens = [ + task.source_dictionary.encode_line( + encode_fn(src_str), add_if_not_exist=False + ).long() + for src_str in lines + ] + + if cfg.generation.constraints: + constraints_tensor = pack_constraints(batch_constraints) + else: + constraints_tensor = None + + lengths = [t.numel() for t in tokens] + itr = task.get_batch_iterator( + dataset=task.build_dataset_for_inference( + tokens, lengths, constraints=constraints_tensor + ), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=max_positions, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + ).next_epoch_itr(shuffle=False) + for batch in itr: + ids = batch["id"] + src_tokens = batch["net_input"]["src_tokens"] + src_lengths = batch["net_input"]["src_lengths"] + constraints = batch.get("constraints", None) + + yield Batch( + ids=ids, + src_tokens=src_tokens, + src_lengths=src_lengths, + constraints=constraints, + ) + + +def main(cfg: FairseqConfig): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + start_time = time.time() + total_translate_time = 0 + + utils.import_user_module(cfg.common) + + if cfg.interactive.buffer_size < 1: + cfg.interactive.buffer_size = 1 + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.batch_size = 1 + + assert ( + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + not cfg.dataset.batch_size + or cfg.dataset.batch_size <= cfg.interactive.buffer_size + ), "--batch-size cannot be larger than --buffer-size" + + logger.info(cfg) + + # Fix seed for stochastic decoding + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + # Setup task, e.g., translation + task = tasks.setup_task(cfg.task) + + # Load ensemble + overrides = ast.literal_eval(cfg.common_eval.model_overrides) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + utils.split_paths(cfg.common_eval.path), + arg_overrides=overrides, + task=task, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + ) + + # Set dictionaries + src_dict = task.source_dictionary + tgt_dict = task.target_dictionary + + # Optimize ensemble for generation + for model in models: + if model is None: + continue + if cfg.common.fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + # Initialize generator + generator = task.build_generator(models, cfg.task) + + # Handle tokenization and BPE + tokenizer = encoders.build_tokenizer(cfg.tokenizer) + bpe = encoders.build_bpe(cfg.bpe) + + def encode_fn(x): + if tokenizer is not None: + x = tokenizer.encode(x) + if bpe is not None: + x = bpe.encode(x) + return x + + def decode_fn(x): + if bpe is not None: + x = bpe.decode(x) + if tokenizer is not None: + x = tokenizer.decode(x) + return x + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) + + max_positions = utils.resolve_max_positions( + task.max_positions(), *[model.max_positions() for model in models] + ) + + if cfg.generation.constraints: + logger.warning( + "NOTE: Constrained decoding currently assumes a shared subword vocabulary." + ) + + if cfg.interactive.buffer_size > 1: + logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size) + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info("Type the input sentence and press return:") + start_id = 0 + for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size): + results = [] + for batch in make_batches(inputs, cfg, task, max_positions, encode_fn): + bsz = batch.src_tokens.size(0) + src_tokens = batch.src_tokens + src_lengths = batch.src_lengths + constraints = batch.constraints + if use_cuda: + src_tokens = src_tokens.cuda() + src_lengths = src_lengths.cuda() + if constraints is not None: + constraints = constraints.cuda() + + sample = { + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + } + translate_start_time = time.time() + translations = task.inference_step( + generator, models, sample, constraints=constraints + ) + translate_time = time.time() - translate_start_time + total_translate_time += translate_time + list_constraints = [[] for _ in range(bsz)] + if cfg.generation.constraints: + list_constraints = [unpack_constraints(c) for c in constraints] + for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): + src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) + constraints = list_constraints[i] + results.append( + ( + start_id + id, + src_tokens_i, + hypos, + { + "constraints": constraints, + "time": translate_time / len(translations), + }, + ) + ) + + # sort output to match input order + for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): + if src_dict is not None: + src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) + print("S-{}\t{}".format(id_, src_str)) + print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) + for constraint in info["constraints"]: + print( + "C-{}\t{}".format( + id_, tgt_dict.string(constraint, cfg.common_eval.post_process) + ) + ) + + # Process top predictions + for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]: + hypo_tokens, hypo_str, alignment = utils.post_process_prediction( + hypo_tokens=hypo["tokens"].int().cpu(), + src_str=src_str, + alignment=hypo["alignment"], + align_dict=align_dict, + tgt_dict=tgt_dict, + remove_bpe=cfg.common_eval.post_process, + extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), + ) + detok_hypo_str = decode_fn(hypo_str) + score = hypo["score"] / math.log(2) # convert to base 2 + # original hypothesis (after tokenization and BPE) + print("H-{}\t{}\t{}".format(id_, score, hypo_str)) + # detokenized hypothesis + print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str)) + print( + "P-{}\t{}".format( + id_, + " ".join( + map( + lambda x: "{:.4f}".format(x), + # convert from base e to base 2 + hypo["positional_scores"].div_(math.log(2)).tolist(), + ) + ), + ) + ) + if cfg.generation.print_alignment: + alignment_str = " ".join( + ["{}-{}".format(src, tgt) for src, tgt in alignment] + ) + print("A-{}\t{}".format(id_, alignment_str)) + + # update running id_ counter + start_id += len(inputs) + + logger.info( + "Total time: {:.3f} seconds; translation time: {:.3f}".format( + time.time() - start_time, total_translate_time + ) + ) + + +def cli_main(): + parser = options.get_interactive_generation_parser() + args = options.parse_args_and_arch(parser) + distributed_utils.call_main(convert_namespace_to_omegaconf(args), main) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py new file mode 100644 index 0000000..fa77da8 --- /dev/null +++ b/fairseq_cli/preprocess.py @@ -0,0 +1,398 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Data pre-processing: build vocabularies and binarize training data. +""" + +import logging +import os +import shutil +import sys +from collections import Counter +from itertools import zip_longest +from multiprocessing import Pool + +from fairseq import options, tasks, utils +from fairseq.binarizer import Binarizer +from fairseq.data import indexed_dataset + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.preprocess") + + +def main(args): + utils.import_user_module(args) + + os.makedirs(args.destdir, exist_ok=True) + + logger.addHandler( + logging.FileHandler( + filename=os.path.join(args.destdir, "preprocess.log"), + ) + ) + logger.info(args) + + task = tasks.get_task(args.task) + + def train_path(lang): + return "{}{}".format(args.trainpref, ("." + lang) if lang else "") + + def file_name(prefix, lang): + fname = prefix + if lang is not None: + fname += ".{lang}".format(lang=lang) + return fname + + def dest_path(prefix, lang): + return os.path.join(args.destdir, file_name(prefix, lang)) + + def dict_path(lang): + return dest_path("dict", lang) + ".txt" + + def build_dictionary(filenames, src=False, tgt=False): + assert src ^ tgt + return task.build_dictionary( + filenames, + workers=args.workers, + threshold=args.thresholdsrc if src else args.thresholdtgt, + nwords=args.nwordssrc if src else args.nwordstgt, + padding_factor=args.padding_factor, + ) + + target = not args.only_source + + if not args.srcdict and os.path.exists(dict_path(args.source_lang)): + raise FileExistsError(dict_path(args.source_lang)) + if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)): + raise FileExistsError(dict_path(args.target_lang)) + + if args.joined_dictionary: + assert ( + not args.srcdict or not args.tgtdict + ), "cannot use both --srcdict and --tgtdict with --joined-dictionary" + + if args.srcdict: + src_dict = task.load_dictionary(args.srcdict) + elif args.tgtdict: + src_dict = task.load_dictionary(args.tgtdict) + else: + assert ( + args.trainpref + ), "--trainpref must be set if --srcdict is not specified" + src_dict = build_dictionary( + {train_path(lang) for lang in [args.source_lang, args.target_lang]}, + src=True, + ) + tgt_dict = src_dict + else: + if args.srcdict: + src_dict = task.load_dictionary(args.srcdict) + else: + assert ( + args.trainpref + ), "--trainpref must be set if --srcdict is not specified" + src_dict = build_dictionary([train_path(args.source_lang)], src=True) + + if target: + if args.tgtdict: + tgt_dict = task.load_dictionary(args.tgtdict) + else: + assert ( + args.trainpref + ), "--trainpref must be set if --tgtdict is not specified" + tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True) + else: + tgt_dict = None + + src_dict.save(dict_path(args.source_lang)) + if target and tgt_dict is not None: + tgt_dict.save(dict_path(args.target_lang)) + + def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers): + logger.info("[{}] Dictionary: {} types".format(lang, len(vocab))) + n_seq_tok = [0, 0] + replaced = Counter() + + def merge_result(worker_result): + replaced.update(worker_result["replaced"]) + n_seq_tok[0] += worker_result["nseq"] + n_seq_tok[1] += worker_result["ntok"] + + input_file = "{}{}".format( + input_prefix, ("." + lang) if lang is not None else "" + ) + offsets = Binarizer.find_offsets(input_file, num_workers) + pool = None + if num_workers > 1: + pool = Pool(processes=num_workers - 1) + for worker_id in range(1, num_workers): + prefix = "{}{}".format(output_prefix, worker_id) + pool.apply_async( + binarize, + ( + args, + input_file, + vocab, + prefix, + lang, + offsets[worker_id], + offsets[worker_id + 1], + ), + callback=merge_result, + ) + pool.close() + + ds = indexed_dataset.make_builder( + dataset_dest_file(args, output_prefix, lang, "bin"), + impl=args.dataset_impl, + vocab_size=len(vocab), + ) + merge_result( + Binarizer.binarize( + input_file, vocab, lambda t: ds.add_item(t), offset=0, end=offsets[1] + ) + ) + if num_workers > 1: + pool.join() + for worker_id in range(1, num_workers): + prefix = "{}{}".format(output_prefix, worker_id) + temp_file_path = dataset_dest_prefix(args, prefix, lang) + ds.merge_file_(temp_file_path) + os.remove(indexed_dataset.data_file_path(temp_file_path)) + os.remove(indexed_dataset.index_file_path(temp_file_path)) + + ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) + + logger.info( + "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format( + lang, + input_file, + n_seq_tok[0], + n_seq_tok[1], + 100 * sum(replaced.values()) / n_seq_tok[1], + vocab.unk_word, + ) + ) + + def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers): + nseq = [0] + + def merge_result(worker_result): + nseq[0] += worker_result["nseq"] + + input_file = input_prefix + offsets = Binarizer.find_offsets(input_file, num_workers) + pool = None + if num_workers > 1: + pool = Pool(processes=num_workers - 1) + for worker_id in range(1, num_workers): + prefix = "{}{}".format(output_prefix, worker_id) + pool.apply_async( + binarize_alignments, + ( + args, + input_file, + utils.parse_alignment, + prefix, + offsets[worker_id], + offsets[worker_id + 1], + ), + callback=merge_result, + ) + pool.close() + + ds = indexed_dataset.make_builder( + dataset_dest_file(args, output_prefix, None, "bin"), impl=args.dataset_impl + ) + + merge_result( + Binarizer.binarize_alignments( + input_file, + utils.parse_alignment, + lambda t: ds.add_item(t), + offset=0, + end=offsets[1], + ) + ) + if num_workers > 1: + pool.join() + for worker_id in range(1, num_workers): + prefix = "{}{}".format(output_prefix, worker_id) + temp_file_path = dataset_dest_prefix(args, prefix, None) + ds.merge_file_(temp_file_path) + os.remove(indexed_dataset.data_file_path(temp_file_path)) + os.remove(indexed_dataset.index_file_path(temp_file_path)) + + ds.finalize(dataset_dest_file(args, output_prefix, None, "idx")) + + logger.info("[alignments] {}: parsed {} alignments".format(input_file, nseq[0])) + + def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1): + if args.dataset_impl == "raw": + # Copy original text file to destination folder + output_text_file = dest_path( + output_prefix + ".{}-{}".format(args.source_lang, args.target_lang), + lang, + ) + shutil.copyfile(file_name(input_prefix, lang), output_text_file) + else: + make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers) + + def make_all(lang, vocab): + if args.trainpref: + make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers) + if args.validpref: + for k, validpref in enumerate(args.validpref.split(",")): + outprefix = "valid{}".format(k) if k > 0 else "valid" + make_dataset( + vocab, validpref, outprefix, lang, num_workers=args.workers + ) + if args.testpref: + for k, testpref in enumerate(args.testpref.split(",")): + outprefix = "test{}".format(k) if k > 0 else "test" + make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers) + + def make_all_alignments(): + if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix): + make_binary_alignment_dataset( + args.trainpref + "." + args.align_suffix, + "train.align", + num_workers=args.workers, + ) + if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix): + make_binary_alignment_dataset( + args.validpref + "." + args.align_suffix, + "valid.align", + num_workers=args.workers, + ) + if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix): + make_binary_alignment_dataset( + args.testpref + "." + args.align_suffix, + "test.align", + num_workers=args.workers, + ) + + make_all(args.source_lang, src_dict) + if target: + make_all(args.target_lang, tgt_dict) + if args.align_suffix: + make_all_alignments() + + logger.info("Wrote preprocessed data to {}".format(args.destdir)) + + if args.alignfile: + assert args.trainpref, "--trainpref must be set if --alignfile is specified" + src_file_name = train_path(args.source_lang) + tgt_file_name = train_path(args.target_lang) + freq_map = {} + with open(args.alignfile, "r", encoding="utf-8") as align_file: + with open(src_file_name, "r", encoding="utf-8") as src_file: + with open(tgt_file_name, "r", encoding="utf-8") as tgt_file: + for a, s, t in zip_longest(align_file, src_file, tgt_file): + si = src_dict.encode_line(s, add_if_not_exist=False) + ti = tgt_dict.encode_line(t, add_if_not_exist=False) + ai = list(map(lambda x: tuple(x.split("-")), a.split())) + for sai, tai in ai: + srcidx = si[int(sai)] + tgtidx = ti[int(tai)] + if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk(): + assert srcidx != src_dict.pad() + assert srcidx != src_dict.eos() + assert tgtidx != tgt_dict.pad() + assert tgtidx != tgt_dict.eos() + + if srcidx not in freq_map: + freq_map[srcidx] = {} + if tgtidx not in freq_map[srcidx]: + freq_map[srcidx][tgtidx] = 1 + else: + freq_map[srcidx][tgtidx] += 1 + + align_dict = {} + for srcidx in freq_map.keys(): + align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get) + + with open( + os.path.join( + args.destdir, + "alignment.{}-{}.txt".format(args.source_lang, args.target_lang), + ), + "w", + encoding="utf-8", + ) as f: + for k, v in align_dict.items(): + print("{} {}".format(src_dict[k], tgt_dict[v]), file=f) + + +def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True): + ds = indexed_dataset.make_builder( + dataset_dest_file(args, output_prefix, lang, "bin"), + impl=args.dataset_impl, + vocab_size=len(vocab), + ) + + def consumer(tensor): + ds.add_item(tensor) + + res = Binarizer.binarize( + filename, vocab, consumer, append_eos=append_eos, offset=offset, end=end + ) + ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx")) + return res + + +def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end): + ds = indexed_dataset.make_builder( + dataset_dest_file(args, output_prefix, None, "bin"), + impl=args.dataset_impl, + vocab_size=None, + ) + + def consumer(tensor): + ds.add_item(tensor) + + res = Binarizer.binarize_alignments( + filename, parse_alignment, consumer, offset=offset, end=end + ) + ds.finalize(dataset_dest_file(args, output_prefix, None, "idx")) + return res + + +def dataset_dest_prefix(args, output_prefix, lang): + base = "{}/{}".format(args.destdir, output_prefix) + if lang is not None: + lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) + elif args.only_source: + lang_part = "" + else: + lang_part = ".{}-{}".format(args.source_lang, args.target_lang) + + return "{}{}".format(base, lang_part) + + +def dataset_dest_file(args, output_prefix, lang, extension): + base = dataset_dest_prefix(args, output_prefix, lang) + return "{}.{}".format(base, extension) + + +def get_offsets(input_file, num_workers): + return Binarizer.find_offsets(input_file, num_workers) + + +def cli_main(): + parser = options.get_preprocessing_parser() + args = parser.parse_args() + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq_cli/score.py b/fairseq_cli/score.py new file mode 100644 index 0000000..e06d672 --- /dev/null +++ b/fairseq_cli/score.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +BLEU scoring of generated translations against reference translations. +""" + +import argparse +import os +import sys + +from fairseq.data import dictionary +from fairseq.scoring import bleu + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Command-line script for BLEU scoring." + ) + # fmt: off + parser.add_argument('-s', '--sys', default='-', help='system output') + parser.add_argument('-r', '--ref', required=True, help='references') + parser.add_argument('-o', '--order', default=4, metavar='N', + type=int, help='consider ngrams up to this order') + parser.add_argument('--ignore-case', action='store_true', + help='case-insensitive scoring') + parser.add_argument('--sacrebleu', action='store_true', + help='score with sacrebleu') + parser.add_argument('--sentence-bleu', action='store_true', + help='report sentence-level BLEUs (i.e., with +1 smoothing)') + # fmt: on + return parser + + +def cli_main(): + parser = get_parser() + args = parser.parse_args() + print(args) + + assert args.sys == "-" or os.path.exists( + args.sys + ), "System output file {} does not exist".format(args.sys) + assert os.path.exists(args.ref), "Reference file {} does not exist".format(args.ref) + + dict = dictionary.Dictionary() + + def readlines(fd): + for line in fd.readlines(): + if args.ignore_case: + yield line.lower() + else: + yield line + + if args.sacrebleu: + import sacrebleu + + def score(fdsys): + with open(args.ref) as fdref: + print(sacrebleu.corpus_bleu(fdsys, [fdref])) + + elif args.sentence_bleu: + + def score(fdsys): + with open(args.ref) as fdref: + scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + for i, (sys_tok, ref_tok) in enumerate( + zip(readlines(fdsys), readlines(fdref)) + ): + scorer.reset(one_init=True) + sys_tok = dict.encode_line(sys_tok) + ref_tok = dict.encode_line(ref_tok) + scorer.add(ref_tok, sys_tok) + print(i, scorer.result_string(args.order)) + + else: + + def score(fdsys): + with open(args.ref) as fdref: + scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) + for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): + sys_tok = dict.encode_line(sys_tok) + ref_tok = dict.encode_line(ref_tok) + scorer.add(ref_tok, sys_tok) + print(scorer.result_string(args.order)) + + if args.sys == "-": + score(sys.stdin) + else: + with open(args.sys, "r") as f: + score(f) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py new file mode 100644 index 0000000..ec10028 --- /dev/null +++ b/fairseq_cli/train.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Train a new model on one or across multiple GPUs. +""" + +import argparse +import logging +import math +import os +import sys +from typing import Dict, Optional, Any, List, Tuple, Callable + +import numpy as np +import torch + +from fairseq import ( + checkpoint_utils, + distributed_utils, + options, + quantization_utils, + tasks, + utils, +) +from fairseq.data import iterators +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import meters, metrics, progress_bar +from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from omegaconf import DictConfig +from fairseq.trainer import Trainer + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.train") + + +def main(cfg: DictConfig) -> None: + if isinstance(cfg, argparse.Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + + assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \ + 'Must specify batch size either with --max-tokens or --batch-size' + metrics.reset() + + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + if distributed_utils.is_master(cfg.distributed_training): + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) + + # Print args + logger.info(cfg) + + # Setup task, e.g., translation, language modeling, etc. + task = tasks.setup_task(cfg.task) + # Load valid dataset (we load training data below, based on the latest checkpoint) + for valid_sub_split in cfg.dataset.valid_subset.split(','): + task.load_dataset(valid_sub_split, combine=False, epoch=1) + + # Build model and criterion + model = task.build_model(cfg.model) + criterion = task.build_criterion(cfg.criterion) + logger.info(model) + logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__)) + logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__)) + logger.info( + "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__) + ) + logger.info("num. model params: {} (num. trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + )) + + # (optionally) Configure quantization + if cfg.common.quantization_config_path is not None: + quantizer = quantization_utils.Quantizer( + config_path=cfg.common.quantization_config_path, + max_epoch=cfg.optimization.max_epoch, + max_update=cfg.optimization.max_update, + ) + else: + quantizer = None + + # Build trainer + if cfg.common.model_parallel_size == 1: + trainer = Trainer(cfg, task, model, criterion, quantizer) + else: + trainer = MegatronTrainer(cfg, task, model, criterion) + + logger.info('training on {} devices (GPUs/TPUs)'.format(cfg.distributed_training.distributed_world_size)) + logger.info('max tokens per GPU = {} and batch size per GPU = {}'.format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + )) + + # Load the latest checkpoint if one is available and restore the + # corresponding train iterator + extra_state, epoch_itr = checkpoint_utils.load_checkpoint( + cfg.checkpoint, + trainer, + # don't cache epoch iterators for sharded datasets + disable_iterator_cache=task.has_sharded_data("train"), + ) + + max_epoch = cfg.optimization.max_epoch or math.inf + lr = trainer.get_lr() + train_meter = meters.StopwatchMeter() + train_meter.start() + while ( + lr > cfg.optimization.min_lr + and epoch_itr.next_epoch_idx <= max_epoch + ): + # train for one epoch + valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) + if should_stop: + break + + # only use first validation loss to update the learning rate + lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) + + epoch_itr = trainer.get_train_iterator( + epoch_itr.next_epoch_idx, + # sharded data: get train iterator for next epoch + load_dataset=task.has_sharded_data("train"), + # don't cache epoch iterators for sharded datasets + disable_iterator_cache=task.has_sharded_data("train"), + ) + train_meter.stop() + logger.info("done training in {:.1f} seconds".format(train_meter.sum)) + + +def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: + # skip check if no validation was done in the current epoch + if valid_loss is None: + return False + if cfg.checkpoint.patience <= 0: + return False + + def is_better(a, b): + return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b + + prev_best = getattr(should_stop_early, "best", None) + if prev_best is None or is_better(valid_loss, prev_best): + should_stop_early.best = valid_loss + should_stop_early.num_runs = 0 + return False + else: + should_stop_early.num_runs += 1 + if should_stop_early.num_runs >= cfg.checkpoint.patience: + logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(cfg.checkpoint.patience)) + return True + else: + return False + + +@metrics.aggregate("train") +def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]: + """Train the model for one epoch and return validation losses.""" + # Initialize data iterator + itr = epoch_itr.next_epoch_itr( + fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus, + shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum), + ) + update_freq = ( + cfg.optimization.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(cfg.optimization.update_freq) + else cfg.optimization.update_freq[-1] + ) + itr = iterators.GroupedIterator(itr, update_freq) + if getattr(cfg.common, "tpu", False): + itr = utils.tpu_data_loader(itr) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + epoch=epoch_itr.epoch, + tensorboard_logdir=( + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + ), + default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'), + ) + + trainer.begin_epoch(epoch_itr.epoch) + + valid_subsets = cfg.dataset.valid_subset.split(',') + should_stop = False + num_updates = trainer.get_num_updates() + for i, samples in enumerate(progress): + with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( + "train_step-%d" % i + ): + log_output = trainer.train_step(samples) + + if log_output is not None: # not OOM, overflow, ... + # log mid-epoch stats + num_updates = trainer.get_num_updates() + if num_updates % cfg.common.log_interval == 0: + stats = get_training_stats(metrics.get_smoothed_values("train_inner")) + progress.log(stats, tag="train_inner", step=num_updates) + + # reset mid-epoch stats after each log interval + # the end-of-epoch stats will still be preserved + metrics.reset_meters("train_inner") + + end_of_epoch = not itr.has_next() + valid_losses, should_stop = validate_and_save( + cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch + ) + + if should_stop: + break + + # log end-of-epoch stats + logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch)) + stats = get_training_stats(metrics.get_smoothed_values("train")) + progress.print(stats, tag="train", step=num_updates) + + # reset epoch-level meters + metrics.reset_meters("train") + return valid_losses, should_stop + + +def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]: + num_updates = trainer.get_num_updates() + max_update = cfg.optimization.max_update or math.inf + do_save = ( + (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) + or num_updates >= max_update + or ( + cfg.checkpoint.save_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates >= cfg.dataset.validate_after_updates + ) + ) + do_validate = ( + (not end_of_epoch and do_save) # validate during mid-epoch saves + or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) + or num_updates >= max_update + or ( + cfg.dataset.validate_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 + ) + ) and not cfg.dataset.disable_validation + + # Validate + valid_losses = [None] + if do_validate: + valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) + + # Stopping conditions + should_stop = ( + should_stop_early(cfg, valid_losses[0]) + or num_updates >= max_update + or ( + cfg.optimization.stop_time_hours > 0 + and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours + ) + ) + + # Save checkpoint + if do_save or should_stop: + logger.info("begin save checkpoint") + checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0]) + + return valid_losses, should_stop + + +def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: + stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) + return stats + + +def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]: + """Evaluate the model on the validation set(s) and return the losses.""" + + if cfg.dataset.fixed_validation_seed is not None: + # set fixed seed for every validation + utils.set_torch_seed(cfg.dataset.fixed_validation_seed) + + trainer.begin_valid_epoch(epoch_itr.epoch) + valid_losses = [] + for subset in subsets: + logger.info('begin validation on "{}" subset'.format(subset)) + + # Initialize data iterator + itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) + if cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + epoch=epoch_itr.epoch, + prefix=f"valid on '{subset}' subset", + tensorboard_logdir=( + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None + ), + default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'), + ) + + # create a new root metrics aggregator so validation metrics + # don't pollute other aggregators (e.g., train meters) + with metrics.aggregate(new_root=True) as agg: + for sample in progress: + trainer.valid_step(sample) + + # log validation stats + stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values()) + progress.print(stats, tag=subset, step=trainer.get_num_updates()) + + valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) + return valid_losses + + +def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]: + stats["num_updates"] = trainer.get_num_updates() + if hasattr(checkpoint_utils.save_checkpoint, "best"): + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min + stats[key] = best_function( + checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric] + ) + return stats + + +def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None: + parser = options.get_training_parser() + args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + + cfg = convert_namespace_to_omegaconf(args) + + if args.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, main) + else: + distributed_utils.call_main(cfg, main) + + +if __name__ == '__main__': + cli_main() diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py new file mode 100644 index 0000000..a1e577e --- /dev/null +++ b/fairseq_cli/validate.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 -u +# !/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys +from argparse import Namespace +from itertools import chain + +import torch +from fairseq import checkpoint_utils, distributed_utils, options, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import metrics, progress_bar +from omegaconf import DictConfig + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.validate") + + +def main(cfg: DictConfig, override_args=None): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + + assert ( + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" + + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) + + if override_args is not None: + overrides = vars(override_args) + overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) + else: + overrides = None + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=overrides, + suffix=cfg.checkpoint.checkpoint_suffix, + ) + model = models[0] + + # Move models to GPU + for model in models: + if use_fp16: + model.half() + if use_cuda: + model.cuda() + + # Print args + logger.info(model_args) + + # Build criterion + criterion = task.build_criterion(model_args.criterion) + criterion.eval() + + for subset in cfg.dataset.valid_subset.split(","): + try: + task.load_dataset(subset, combine=False, epoch=1) + dataset = task.dataset(subset) + except KeyError: + raise Exception("Cannot find dataset: " + subset) + + # Initialize data iterator + itr = task.get_batch_iterator( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + task.max_positions(), + *[m.max_positions() for m in models], + ), + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + prefix=f"valid on '{subset}' subset", + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + log_outputs = [] + for i, sample in enumerate(progress): + sample = utils.move_to_cuda(sample) if use_cuda else sample + _loss, _sample_size, log_output = task.valid_step(sample, model, criterion) + progress.log(log_output, step=i) + log_outputs.append(log_output) + + if cfg.distributed_training.distributed_world_size > 1: + log_outputs = distributed_utils.all_gather_list( + log_outputs, + max_size=cfg.common.all_gather_list_size, + ) + log_outputs = list(chain.from_iterable(log_outputs)) + + with metrics.aggregate() as agg: + task.reduce_metrics(log_outputs, criterion) + log_output = agg.get_smoothed_values() + + progress.print(log_output, tag=subset, step=i) + + +def cli_main(): + parser = options.get_validation_parser() + args = options.parse_args_and_arch(parser) + + # only override args that are explicitly given on the command line + override_parser = options.get_validation_parser() + override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + + distributed_utils.call_main(convert_namespace_to_omegaconf(args), main, override_args=override_args) + + +if __name__ == "__main__": + cli_main() diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000..ce7d76c --- /dev/null +++ b/hubconf.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import importlib + +from fairseq.hub_utils import ( # noqa; noqa + BPEHubInterface as bpe, + TokenizerHubInterface as tokenizer, +) +from fairseq.models import MODEL_REGISTRY # noqa + + +dependencies = [ + "dataclasses", + "hydra", + "numpy", + "regex", + "requests", + "torch", +] + + +# Check for required dependencies and raise a RuntimeError if any are missing. +missing_deps = [] +for dep in dependencies: + try: + importlib.import_module(dep) + except ImportError: + # Hack: the hydra package is provided under the "hydra-core" name in + # pypi. We don't want the user mistakenly calling `pip install hydra` + # since that will install an unrelated package. + if dep == "hydra": + dep = "hydra-core" + missing_deps.append(dep) +if len(missing_deps) > 0: + raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) + + +# torch.hub doesn't build Cython components, so if they are not found then try +# to build them here +try: + import fairseq.data.token_block_utils_fast # noqa +except ImportError: + try: + import cython # noqa + import os + from setuptools import sandbox + + sandbox.run_setup( + os.path.join(os.path.dirname(__file__), "setup.py"), + ["build_ext", "--inplace"], + ) + except ImportError: + print( + "Unable to build Cython components. Please make sure Cython is " + "installed if the torch.hub model you are loading depends on it." + ) + + +# automatically expose models defined in FairseqModel::hub_models +for _model_type, _cls in MODEL_REGISTRY.items(): + for model_name in _cls.hub_models().keys(): + globals()[model_name] = functools.partial( + _cls.from_pretrained, + model_name, + ) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6d1b4c5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "wheel", "cython"] +build-backend = "setuptools.build_meta" diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py new file mode 100644 index 0000000..c512f80 --- /dev/null +++ b/scripts/average_checkpoints.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import collections +import os +import re + +import torch +from fairseq.file_io import PathManager + + +def average_checkpoints(inputs): + """Loads checkpoints from inputs and returns a model with averaged weights. + + Args: + inputs: An iterable of string paths of checkpoints to load from. + + Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. + """ + params_dict = collections.OrderedDict() + params_keys = None + new_state = None + num_models = len(inputs) + + for fpath in inputs: + with PathManager.open(fpath, "rb") as f: + state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), + ) + # Copies over the settings from the first checkpoint + if new_state is None: + new_state = state + + model_params = state["model"] + + model_params_keys = list(model_params.keys()) + if params_keys is None: + params_keys = model_params_keys + elif params_keys != model_params_keys: + raise KeyError( + "For checkpoint {}, expected list of params: {}, " + "but found: {}".format(f, params_keys, model_params_keys) + ) + + for k in params_keys: + p = model_params[k] + if isinstance(p, torch.HalfTensor): + p = p.float() + if k not in params_dict: + params_dict[k] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + params_dict[k] += p + + averaged_params = collections.OrderedDict() + for k, v in params_dict.items(): + averaged_params[k] = v + if averaged_params[k].is_floating_point(): + averaged_params[k].div_(num_models) + else: + averaged_params[k] //= num_models + new_state["model"] = averaged_params + return new_state + + +def last_n_checkpoints(paths, n, update_based, upper_bound=None): + assert len(paths) == 1 + path = paths[0] + if update_based: + pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt") + else: + pt_regexp = re.compile(r"checkpoint(\d+)\.pt") + files = PathManager.ls(path) + + entries = [] + for f in files: + m = pt_regexp.fullmatch(f) + if m is not None: + sort_key = int(m.group(1)) + if upper_bound is None or sort_key <= upper_bound: + entries.append((sort_key, m.group(0))) + if len(entries) < n: + raise Exception( + "Found {} checkpoint files but need at least {}", len(entries), n + ) + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] + + +def main(): + parser = argparse.ArgumentParser( + description="Tool to average the params of input checkpoints to " + "produce a new checkpoint", + ) + # fmt: off + parser.add_argument('--inputs', required=True, nargs='+', + help='Input checkpoint file paths.') + parser.add_argument('--output', required=True, metavar='FILE', + help='Write the new checkpoint containing the averaged weights to this path.') + num_group = parser.add_mutually_exclusive_group() + num_group.add_argument('--num-epoch-checkpoints', type=int, + help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' + 'and average last this many of them.') + num_group.add_argument('--num-update-checkpoints', type=int, + help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, ' + 'and average last this many of them.') + parser.add_argument('--checkpoint-upper-bound', type=int, + help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, ' + 'when using --num-update-checkpoints, this will set an upper bound on which update to use' + 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.' + 'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would be averaged assuming --save-interval-updates 500' + ) + # fmt: on + args = parser.parse_args() + print(args) + + num = None + is_update_based = False + if args.num_update_checkpoints is not None: + num = args.num_update_checkpoints + is_update_based = True + elif args.num_epoch_checkpoints is not None: + num = args.num_epoch_checkpoints + + assert args.checkpoint_upper_bound is None or ( + args.num_epoch_checkpoints is not None + or args.num_update_checkpoints is not None + ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints" + assert ( + args.num_epoch_checkpoints is None or args.num_update_checkpoints is None + ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints" + + if num is not None: + args.inputs = last_n_checkpoints( + args.inputs, + num, + is_update_based, + upper_bound=args.checkpoint_upper_bound, + ) + print("averaging checkpoints: ", args.inputs) + + new_state = average_checkpoints(args.inputs) + with PathManager.open(args.output, "wb") as f: + torch.save(new_state, f) + print("Finished writing averaged checkpoint to {}".format(args.output)) + + +if __name__ == "__main__": + main() diff --git a/scripts/build_sym_alignment.py b/scripts/build_sym_alignment.py new file mode 100644 index 0000000..0ca5c18 --- /dev/null +++ b/scripts/build_sym_alignment.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Use this script in order to build symmetric alignments for your translation +dataset. +This script depends on fast_align and mosesdecoder tools. You will need to +build those before running the script. +fast_align: + github: http://github.com/clab/fast_align + instructions: follow the instructions in README.md +mosesdecoder: + github: http://github.com/moses-smt/mosesdecoder + instructions: http://www.statmt.org/moses/?n=Development.GetStarted +The script produces the following files under --output_dir: + text.joined - concatenation of lines from the source_file and the + target_file. + align.forward - forward pass of fast_align. + align.backward - backward pass of fast_align. + aligned.sym_heuristic - symmetrized alignment. +""" + +import argparse +import os +from itertools import zip_longest + + +def main(): + parser = argparse.ArgumentParser(description="symmetric alignment builer") + # fmt: off + parser.add_argument('--fast_align_dir', + help='path to fast_align build directory') + parser.add_argument('--mosesdecoder_dir', + help='path to mosesdecoder root directory') + parser.add_argument('--sym_heuristic', + help='heuristic to use for symmetrization', + default='grow-diag-final-and') + parser.add_argument('--source_file', + help='path to a file with sentences ' + 'in the source language') + parser.add_argument('--target_file', + help='path to a file with sentences ' + 'in the target language') + parser.add_argument('--output_dir', + help='output directory') + # fmt: on + args = parser.parse_args() + + fast_align_bin = os.path.join(args.fast_align_dir, "fast_align") + symal_bin = os.path.join(args.mosesdecoder_dir, "bin", "symal") + sym_fast_align_bin = os.path.join( + args.mosesdecoder_dir, "scripts", "ems", "support", "symmetrize-fast-align.perl" + ) + + # create joined file + joined_file = os.path.join(args.output_dir, "text.joined") + with open(args.source_file, "r", encoding="utf-8") as src, open( + args.target_file, "r", encoding="utf-8" + ) as tgt: + with open(joined_file, "w", encoding="utf-8") as joined: + for s, t in zip_longest(src, tgt): + print("{} ||| {}".format(s.strip(), t.strip()), file=joined) + + bwd_align_file = os.path.join(args.output_dir, "align.backward") + + # run forward alignment + fwd_align_file = os.path.join(args.output_dir, "align.forward") + fwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v > {FWD}".format( + FASTALIGN=fast_align_bin, JOINED=joined_file, FWD=fwd_align_file + ) + assert os.system(fwd_fast_align_cmd) == 0 + + # run backward alignment + bwd_align_file = os.path.join(args.output_dir, "align.backward") + bwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}".format( + FASTALIGN=fast_align_bin, JOINED=joined_file, BWD=bwd_align_file + ) + assert os.system(bwd_fast_align_cmd) == 0 + + # run symmetrization + sym_out_file = os.path.join(args.output_dir, "aligned") + sym_cmd = "{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}".format( + SYMFASTALIGN=sym_fast_align_bin, + FWD=fwd_align_file, + BWD=bwd_align_file, + SRC=args.source_file, + TGT=args.target_file, + OUT=sym_out_file, + HEURISTIC=args.sym_heuristic, + SYMAL=symal_bin, + ) + assert os.system(sym_cmd) == 0 + + +if __name__ == "__main__": + main() diff --git a/scripts/compare_namespaces.py b/scripts/compare_namespaces.py new file mode 100644 index 0000000..bc24db6 --- /dev/null +++ b/scripts/compare_namespaces.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python +"""Helper script to compare two argparse.Namespace objects.""" + +from argparse import Namespace # noqa + + +def main(): + + ns1 = eval(input("Namespace 1: ")) + ns2 = eval(input("Namespace 2: ")) + + def keys(ns): + ks = set() + for k in dir(ns): + if not k.startswith("_"): + ks.add(k) + return ks + + k1 = keys(ns1) + k2 = keys(ns2) + + def print_keys(ks, ns1, ns2=None): + for k in ks: + if ns2 is None: + print("{}\t{}".format(k, getattr(ns1, k, None))) + else: + print( + "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None)) + ) + + print("Keys unique to namespace 1:") + print_keys(k1 - k2, ns1) + print() + + print("Keys unique to namespace 2:") + print_keys(k2 - k1, ns2) + print() + + print("Overlapping keys with different values:") + ks = [k for k in k1 & k2 if getattr(ns1, k, "None") != getattr(ns2, k, "None")] + print_keys(ks, ns1, ns2) + print() + + +if __name__ == "__main__": + main() diff --git a/scripts/compound_split_bleu.sh b/scripts/compound_split_bleu.sh new file mode 100644 index 0000000..1972fdd --- /dev/null +++ b/scripts/compound_split_bleu.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +if [ $# -ne 1 ]; then + echo "usage: $0 GENERATE_PY_OUTPUT" + exit 1 +fi + +GEN=$1 + +SYS=$GEN.sys +REF=$GEN.ref + +if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then + echo "not done generating" + exit +fi + +grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS +grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF +fairseq-score --sys $SYS --ref $REF diff --git a/scripts/constraints/extract.py b/scripts/constraints/extract.py new file mode 100644 index 0000000..f6155d0 --- /dev/null +++ b/scripts/constraints/extract.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Extracts random constraints from reference files.""" + +import argparse +import random +import sys + +from sacrebleu import extract_ngrams + + +def get_phrase(words, index, length): + assert index < len(words) - length + 1 + phr = " ".join(words[index : index + length]) + for i in range(index, index + length): + words.pop(index) + return phr + + +def main(args): + + if args.seed: + random.seed(args.seed) + + for line in sys.stdin: + constraints = [] + + def add_constraint(constraint): + constraints.append(constraint) + + source = line.rstrip() + if "\t" in line: + source, target = line.split("\t") + if args.add_sos: + target = f" {target}" + if args.add_eos: + target = f"{target} " + + if len(target.split()) >= args.len: + words = [target] + + num = args.number + + choices = {} + for i in range(num): + if len(words) == 0: + break + segmentno = random.choice(range(len(words))) + segment = words.pop(segmentno) + tokens = segment.split() + phrase_index = random.choice(range(len(tokens))) + choice = " ".join( + tokens[phrase_index : min(len(tokens), phrase_index + args.len)] + ) + for j in range( + phrase_index, min(len(tokens), phrase_index + args.len) + ): + tokens.pop(phrase_index) + if phrase_index > 0: + words.append(" ".join(tokens[0:phrase_index])) + if phrase_index + 1 < len(tokens): + words.append(" ".join(tokens[phrase_index:])) + choices[target.find(choice)] = choice + + # mask out with spaces + target = target.replace(choice, " " * len(choice), 1) + + for key in sorted(choices.keys()): + add_constraint(choices[key]) + + print(source, *constraints, sep="\t") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases") + parser.add_argument("--len", "-l", type=int, default=1, help="phrase length") + parser.add_argument( + "--add-sos", default=False, action="store_true", help="add token" + ) + parser.add_argument( + "--add-eos", default=False, action="store_true", help="add token" + ) + parser.add_argument("--seed", "-s", default=0, type=int) + args = parser.parse_args() + + main(args) diff --git a/scripts/constraints/validate.py b/scripts/constraints/validate.py new file mode 100644 index 0000000..d531ad9 --- /dev/null +++ b/scripts/constraints/validate.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + + +"""Reads in a fairseq output file, and verifies that the constraints +(C- lines) are present in the output (the first H- line). Assumes that +constraints are listed prior to the first hypothesis. +""" + +constraints = [] +found = 0 +total = 0 +for line in sys.stdin: + if line.startswith("C-"): + constraints.append(line.rstrip().split("\t")[1]) + elif line.startswith("H-"): + text = line.split("\t")[2] + + for constraint in constraints: + total += 1 + if constraint in text: + found += 1 + else: + print(f"No {constraint} in {text}", file=sys.stderr) + + constraints = [] + +print(f"Found {found} / {total} = {100 * found / total:.1f}%") diff --git a/scripts/convert_dictionary.lua b/scripts/convert_dictionary.lua new file mode 100644 index 0000000..14ee8c9 --- /dev/null +++ b/scripts/convert_dictionary.lua @@ -0,0 +1,34 @@ +-- Copyright (c) Facebook, Inc. and its affiliates. +-- +-- This source code is licensed under the MIT license found in the +-- LICENSE file in the root directory of this source tree. +-- +-- Usage: convert_dictionary.lua +require 'fairseq' +require 'torch' +require 'paths' + +if #arg < 1 then + print('usage: convert_dictionary.lua ') + os.exit(1) +end +if not paths.filep(arg[1]) then + print('error: file does not exit: ' .. arg[1]) + os.exit(1) +end + +dict = torch.load(arg[1]) +dst = paths.basename(arg[1]):gsub('.th7', '.txt') +assert(dst:match('.txt$')) + +f = io.open(dst, 'w') +for idx, symbol in ipairs(dict.index_to_symbol) do + if idx > dict.cutoff then + break + end + f:write(symbol) + f:write(' ') + f:write(dict.index_to_freq[idx]) + f:write('\n') +end +f:close() diff --git a/scripts/convert_model.lua b/scripts/convert_model.lua new file mode 100644 index 0000000..61b9213 --- /dev/null +++ b/scripts/convert_model.lua @@ -0,0 +1,108 @@ +-- Copyright (c) Facebook, Inc. and its affiliates. +-- +-- This source code is licensed under the MIT license found in the +-- LICENSE file in the root directory of this source tree. +-- +-- Usage: convert_model.lua +require 'torch' +local fairseq = require 'fairseq' + +model = torch.load(arg[1]) + +function find_weight_norm(container, module) + for _, wn in ipairs(container:listModules()) do + if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then + return wn + end + end +end + +function push_state(dict, key, module) + if torch.type(module) == 'nn.Linear' then + local wn = find_weight_norm(model.module, module) + assert(wn) + dict[key .. '.weight_v'] = wn.v:float() + dict[key .. '.weight_g'] = wn.g:float() + elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then + local wn = find_weight_norm(model.module, module) + assert(wn) + local v = wn.v:float():view(wn.viewOut):transpose(2, 3) + dict[key .. '.weight_v'] = v + dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1) + else + dict[key .. '.weight'] = module.weight:float() + end + if module.bias then + dict[key .. '.bias'] = module.bias:float() + end +end + +encoder_dict = {} +decoder_dict = {} +combined_dict = {} + +function encoder_state(encoder) + luts = encoder:findModules('nn.LookupTable') + push_state(encoder_dict, 'embed_tokens', luts[1]) + push_state(encoder_dict, 'embed_positions', luts[2]) + + fcs = encoder:findModules('nn.Linear') + assert(#fcs >= 2) + local nInputPlane = fcs[1].weight:size(1) + push_state(encoder_dict, 'fc1', table.remove(fcs, 1)) + push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs)) + + for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do + push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module) + if nInputPlane ~= module.weight:size(3) / 2 then + push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) + end + nInputPlane = module.weight:size(3) / 2 + end + assert(#fcs == 0) +end + +function decoder_state(decoder) + luts = decoder:findModules('nn.LookupTable') + push_state(decoder_dict, 'embed_tokens', luts[1]) + push_state(decoder_dict, 'embed_positions', luts[2]) + + fcs = decoder:findModules('nn.Linear') + local nInputPlane = fcs[1].weight:size(1) + push_state(decoder_dict, 'fc1', table.remove(fcs, 1)) + push_state(decoder_dict, 'fc2', fcs[#fcs - 1]) + push_state(decoder_dict, 'fc3', fcs[#fcs]) + + table.remove(fcs, #fcs) + table.remove(fcs, #fcs) + + for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do + if nInputPlane ~= module.weight:size(3) / 2 then + push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1)) + end + nInputPlane = module.weight:size(3) / 2 + + local prefix = 'attention.' .. tostring(i - 1) + push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1)) + push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1)) + push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module) + end + assert(#fcs == 0) +end + + +_encoder = model.module.modules[2] +_decoder = model.module.modules[3] + +encoder_state(_encoder) +decoder_state(_decoder) + +for k, v in pairs(encoder_dict) do + combined_dict['encoder.' .. k] = v +end +for k, v in pairs(decoder_dict) do + combined_dict['decoder.' .. k] = v +end + + +torch.save('state_dict.t7', combined_dict) diff --git a/scripts/count_docs.py b/scripts/count_docs.py new file mode 100644 index 0000000..58d85af --- /dev/null +++ b/scripts/count_docs.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Count the number of documents and average number of lines and tokens per +document in a large file. Documents should be separated by a single empty line. +""" + +import argparse +import gzip +import sys + +import numpy as np + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("--gzip", action="store_true") + args = parser.parse_args() + + def gopen(): + if args.gzip: + return gzip.open(args.input, "r") + else: + return open(args.input, "r", encoding="utf-8") + + num_lines = [] + num_toks = [] + with gopen() as h: + num_docs = 1 + num_lines_in_doc = 0 + num_toks_in_doc = 0 + for i, line in enumerate(h): + if len(line.strip()) == 0: # empty line indicates new document + num_docs += 1 + num_lines.append(num_lines_in_doc) + num_toks.append(num_toks_in_doc) + num_lines_in_doc = 0 + num_toks_in_doc = 0 + else: + num_lines_in_doc += 1 + num_toks_in_doc += len(line.rstrip().split()) + if i % 1000000 == 0: + print(i, file=sys.stderr, end="", flush=True) + elif i % 100000 == 0: + print(".", file=sys.stderr, end="", flush=True) + print(file=sys.stderr, flush=True) + + print("found {} docs".format(num_docs)) + print("average num lines per doc: {}".format(np.mean(num_lines))) + print("average num toks per doc: {}".format(np.mean(num_toks))) + + +if __name__ == "__main__": + main() diff --git a/scripts/read_binarized.py b/scripts/read_binarized.py new file mode 100644 index 0000000..a414095 --- /dev/null +++ b/scripts/read_binarized.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + +from fairseq.data import Dictionary, data_utils, indexed_dataset + + +def get_parser(): + parser = argparse.ArgumentParser( + description="writes text from binarized file to stdout" + ) + # fmt: off + parser.add_argument('--dataset-impl', help='dataset implementation', + choices=indexed_dataset.get_available_dataset_impl()) + parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None) + parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read') + # fmt: on + + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + dictionary = Dictionary.load(args.dict) if args.dict is not None else None + dataset = data_utils.load_indexed_dataset( + args.input, + dictionary, + dataset_impl=args.dataset_impl, + default="lazy", + ) + + for tensor_line in dataset: + if dictionary is None: + line = " ".join([str(int(x)) for x in tensor_line]) + else: + line = dictionary.string(tensor_line) + + print(line) + + +if __name__ == "__main__": + main() diff --git a/scripts/rm_pt.py b/scripts/rm_pt.py new file mode 100644 index 0000000..6cd063d --- /dev/null +++ b/scripts/rm_pt.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import re +import shutil +import sys + + +pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt") +pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt") +pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt") + + +def parse_checkpoints(files): + entries = [] + for f in files: + m = pt_regexp_epoch_based.fullmatch(f) + if m is not None: + entries.append((int(m.group(1)), m.group(0))) + else: + m = pt_regexp_update_based.fullmatch(f) + if m is not None: + entries.append((int(m.group(1)), m.group(0))) + return entries + + +def last_n_checkpoints(files, n): + entries = parse_checkpoints(files) + return [x[1] for x in sorted(entries, reverse=True)[:n]] + + +def every_n_checkpoints(files, n): + entries = parse_checkpoints(files) + return [x[1] for x in sorted(sorted(entries)[::-n])] + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Recursively delete checkpoint files from `root_dir`, " + "but preserve checkpoint_best.pt and checkpoint_last.pt" + ) + ) + parser.add_argument("root_dirs", nargs="*") + parser.add_argument( + "--save-last", type=int, default=0, help="number of last checkpoints to save" + ) + parser.add_argument( + "--save-every", type=int, default=0, help="interval of checkpoints to save" + ) + parser.add_argument( + "--preserve-test", + action="store_true", + help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)", + ) + parser.add_argument( + "--delete-best", action="store_true", help="delete checkpoint_best.pt" + ) + parser.add_argument( + "--delete-last", action="store_true", help="delete checkpoint_last.pt" + ) + parser.add_argument( + "--no-dereference", action="store_true", help="don't dereference symlinks" + ) + args = parser.parse_args() + + files_to_desymlink = [] + files_to_preserve = [] + files_to_delete = [] + for root_dir in args.root_dirs: + for root, _subdirs, files in os.walk(root_dir): + if args.save_last > 0: + to_save = last_n_checkpoints(files, args.save_last) + else: + to_save = [] + if args.save_every > 0: + to_save += every_n_checkpoints(files, args.save_every) + for file in files: + if not pt_regexp.fullmatch(file): + continue + full_path = os.path.join(root, file) + if ( + not os.path.basename(root).startswith("test_") or args.preserve_test + ) and ( + (file == "checkpoint_last.pt" and not args.delete_last) + or (file == "checkpoint_best.pt" and not args.delete_best) + or file in to_save + ): + if os.path.islink(full_path) and not args.no_dereference: + files_to_desymlink.append(full_path) + else: + files_to_preserve.append(full_path) + else: + files_to_delete.append(full_path) + + if len(files_to_desymlink) == 0 and len(files_to_delete) == 0: + print("Nothing to do.") + sys.exit(0) + + files_to_desymlink = sorted(files_to_desymlink) + files_to_preserve = sorted(files_to_preserve) + files_to_delete = sorted(files_to_delete) + + print("Operations to perform (in order):") + if len(files_to_desymlink) > 0: + for file in files_to_desymlink: + print(" - preserve (and dereference symlink): " + file) + if len(files_to_preserve) > 0: + for file in files_to_preserve: + print(" - preserve: " + file) + if len(files_to_delete) > 0: + for file in files_to_delete: + print(" - delete: " + file) + while True: + resp = input("Continue? (Y/N): ") + if resp.strip().lower() == "y": + break + elif resp.strip().lower() == "n": + sys.exit(0) + + print("Executing...") + if len(files_to_desymlink) > 0: + for file in files_to_desymlink: + realpath = os.path.realpath(file) + print("rm " + file) + os.remove(file) + print("cp {} {}".format(realpath, file)) + shutil.copyfile(realpath, file) + if len(files_to_delete) > 0: + for file in files_to_delete: + print("rm " + file) + os.remove(file) + + +if __name__ == "__main__": + main() diff --git a/scripts/sacrebleu.sh b/scripts/sacrebleu.sh new file mode 100644 index 0000000..c10bf2b --- /dev/null +++ b/scripts/sacrebleu.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +if [ $# -ne 4 ]; then + echo "usage: $0 TESTSET SRCLANG TGTLANG GEN" + exit 1 +fi + +TESTSET=$1 +SRCLANG=$2 +TGTLANG=$3 + +GEN=$4 + +if ! command -v sacremoses &> /dev/null +then + echo "sacremoses could not be found, please install with: pip install sacremoses" + exit +fi + +grep ^H $GEN \ +| sed 's/^H\-//' \ +| sort -n -k 1 \ +| cut -f 3 \ +| sacremoses detokenize \ +> $GEN.sorted.detok + +sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok diff --git a/scripts/shard_docs.py b/scripts/shard_docs.py new file mode 100644 index 0000000..97232c3 --- /dev/null +++ b/scripts/shard_docs.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Split a large file into shards while respecting document boundaries. Documents +should be separated by a single empty line. +""" + +import argparse +import contextlib + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("--num-shards", type=int) + args = parser.parse_args() + + assert args.num_shards is not None and args.num_shards > 1 + + with open(args.input, "r", encoding="utf-8") as h: + with contextlib.ExitStack() as stack: + outputs = [ + stack.enter_context( + open(args.input + ".shard" + str(i), "w", encoding="utf-8") + ) + for i in range(args.num_shards) + ] + + doc = [] + first_doc = [True] * args.num_shards + + def output_doc(i): + if not first_doc[i]: + outputs[i].write("\n") + first_doc[i] = False + for line in doc: + outputs[i].write(line) + doc.clear() + + num_docs = 0 + for line in h: + if line.strip() == "": # empty line indicates new document + output_doc(num_docs % args.num_shards) + num_docs += 1 + else: + doc.append(line) + output_doc(num_docs % args.num_shards) + + +if __name__ == "__main__": + main() diff --git a/scripts/split_train_valid_docs.py b/scripts/split_train_valid_docs.py new file mode 100644 index 0000000..ff15978 --- /dev/null +++ b/scripts/split_train_valid_docs.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Split a large file into a train and valid set while respecting document +boundaries. Documents should be separated by a single empty line. +""" + +import argparse +import random +import sys + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("input") + parser.add_argument("sample_output", help="train output file") + parser.add_argument("remainder_output", help="valid output file") + parser.add_argument("-k", type=int, help="remainder size") + parser.add_argument( + "--lines", action="store_true", help="split lines instead of docs" + ) + args = parser.parse_args() + + assert args.k is not None + + sample = [] + remainder = [] + num_docs = [0] + + def update_sample(doc): + if len(sample) < args.k: + sample.append(doc.copy()) + else: + i = num_docs[0] + j = random.randrange(i + 1) + if j < args.k: + remainder.append(sample[j]) + sample[j] = doc.copy() + else: + remainder.append(doc.copy()) + num_docs[0] += 1 + doc.clear() + + with open(args.input, "r", encoding="utf-8") as h: + doc = [] + for i, line in enumerate(h): + if line.strip() == "": # empty line indicates new document + update_sample(doc) + else: + doc.append(line) + if args.lines: + update_sample(doc) + if i % 1000000 == 0: + print(i, file=sys.stderr, end="", flush=True) + elif i % 100000 == 0: + print(".", file=sys.stderr, end="", flush=True) + if len(doc) > 0: + update_sample(doc) + print(file=sys.stderr, flush=True) + + assert len(sample) == args.k + + with open(args.sample_output, "w", encoding="utf-8") as out: + first = True + for doc in sample: + if not first and not args.lines: + out.write("\n") + first = False + for line in doc: + out.write(line) + + with open(args.remainder_output, "w", encoding="utf-8") as out: + first = True + for doc in remainder: + if not first and not args.lines: + out.write("\n") + first = False + for line in doc: + out.write(line) + + +if __name__ == "__main__": + main() diff --git a/scripts/spm_decode.py b/scripts/spm_decode.py new file mode 100644 index 0000000..1c18b1d --- /dev/null +++ b/scripts/spm_decode.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse + +import sentencepiece as spm + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", required=True, help="sentencepiece model to use for decoding" + ) + parser.add_argument("--input", required=True, help="input file to decode") + parser.add_argument("--input_format", choices=["piece", "id"], default="piece") + args = parser.parse_args() + + sp = spm.SentencePieceProcessor() + sp.Load(args.model) + + if args.input_format == "piece": + + def decode(l): + return "".join(sp.DecodePieces(l)) + + elif args.input_format == "id": + + def decode(l): + return "".join(sp.DecodeIds(l)) + + else: + raise NotImplementedError + + def tok2int(tok): + # remap reference-side (represented as <>) to 0 + return int(tok) if tok != "<>" else 0 + + with open(args.input, "r", encoding="utf-8") as h: + for line in h: + if args.input_format == "id": + print(decode(list(map(tok2int, line.rstrip().split())))) + elif args.input_format == "piece": + print(decode(line.rstrip().split())) + + +if __name__ == "__main__": + main() diff --git a/scripts/spm_encode.py b/scripts/spm_encode.py new file mode 100644 index 0000000..83facfb --- /dev/null +++ b/scripts/spm_encode.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import contextlib +import sys + +import sentencepiece as spm + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", required=True, help="sentencepiece model to use for encoding" + ) + parser.add_argument( + "--inputs", nargs="+", default=["-"], help="input files to filter/encode" + ) + parser.add_argument( + "--outputs", nargs="+", default=["-"], help="path to save encoded outputs" + ) + parser.add_argument("--output_format", choices=["piece", "id"], default="piece") + parser.add_argument( + "--min-len", + type=int, + metavar="N", + help="filter sentence pairs with fewer than N tokens", + ) + parser.add_argument( + "--max-len", + type=int, + metavar="N", + help="filter sentence pairs with more than N tokens", + ) + args = parser.parse_args() + + assert len(args.inputs) == len( + args.outputs + ), "number of input and output paths should match" + + sp = spm.SentencePieceProcessor() + sp.Load(args.model) + + if args.output_format == "piece": + + def encode(l): + return sp.EncodeAsPieces(l) + + elif args.output_format == "id": + + def encode(l): + return list(map(str, sp.EncodeAsIds(l))) + + else: + raise NotImplementedError + + if args.min_len is not None or args.max_len is not None: + + def valid(line): + return (args.min_len is None or len(line) >= args.min_len) and ( + args.max_len is None or len(line) <= args.max_len + ) + + else: + + def valid(lines): + return True + + with contextlib.ExitStack() as stack: + inputs = [ + stack.enter_context(open(input, "r", encoding="utf-8")) + if input != "-" + else sys.stdin + for input in args.inputs + ] + outputs = [ + stack.enter_context(open(output, "w", encoding="utf-8")) + if output != "-" + else sys.stdout + for output in args.outputs + ] + + stats = { + "num_empty": 0, + "num_filtered": 0, + } + + def encode_line(line): + line = line.strip() + if len(line) > 0: + line = encode(line) + if valid(line): + return line + else: + stats["num_filtered"] += 1 + else: + stats["num_empty"] += 1 + return None + + for i, lines in enumerate(zip(*inputs), start=1): + enc_lines = list(map(encode_line, lines)) + if not any(enc_line is None for enc_line in enc_lines): + for enc_line, output_h in zip(enc_lines, outputs): + print(" ".join(enc_line), file=output_h) + if i % 10000 == 0: + print("processed {} lines".format(i), file=sys.stderr) + + print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr) + print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/scripts/spm_train.py b/scripts/spm_train.py new file mode 100644 index 0000000..9db668f --- /dev/null +++ b/scripts/spm_train.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import sys + +import sentencepiece as spm + + +if __name__ == "__main__": + spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:])) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..7b13f13 --- /dev/null +++ b/setup.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import subprocess +import sys +from setuptools import setup, find_packages, Extension + +from setuptools import Extension, find_packages, setup + + +if sys.version_info < (3, 6): + sys.exit("Sorry, Python >= 3.6 is required for fairseq.") + + +def write_version_py(): + with open(os.path.join("fairseq", "version.txt")) as f: + version = f.read().strip() + + # append latest commit hash to version string + try: + sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + version += "+" + sha[:7] + except Exception: + pass + + # write version info to fairseq/version.py + with open(os.path.join("fairseq", "version.py"), "w") as f: + f.write("__version__ = \"{}\"\n".format(version)) + return version + + +version = write_version_py() + + +with open("README.md") as f: + readme = f.read() + + +if sys.platform == "darwin": + extra_compile_args = ["-stdlib=libc++", "-O3"] +else: + extra_compile_args = ["-std=c++11", "-O3"] + + +class NumpyExtension(Extension): + """Source: https://stackoverflow.com/a/54128391""" + + def __init__(self, *args, **kwargs): + self.__include_dirs = [] + super().__init__(*args, **kwargs) + + @property + def include_dirs(self): + import numpy + + return self.__include_dirs + [numpy.get_include()] + + @include_dirs.setter + def include_dirs(self, dirs): + self.__include_dirs = dirs + + +extensions = [ + Extension( + "fairseq.libbleu", + sources=[ + "fairseq/clib/libbleu/libbleu.cpp", + "fairseq/clib/libbleu/module.cpp", + ], + extra_compile_args=extra_compile_args, + ), + NumpyExtension( + "fairseq.data.data_utils_fast", + sources=["fairseq/data/data_utils_fast.pyx"], + language="c++", + extra_compile_args=extra_compile_args, + ), + NumpyExtension( + "fairseq.data.token_block_utils_fast", + sources=["fairseq/data/token_block_utils_fast.pyx"], + language="c++", + extra_compile_args=extra_compile_args, + ), +] + + +cmdclass = {} + + +try: + # torch is not available when generating docs + from torch.utils import cpp_extension + + extensions.extend( + [ + cpp_extension.CppExtension( + "fairseq.libnat", + sources=[ + "fairseq/clib/libnat/edit_dist.cpp", + ], + ) + ] + ) + + if "CUDA_HOME" in os.environ: + extensions.extend( + [ + cpp_extension.CppExtension( + "fairseq.libnat_cuda", + sources=[ + "fairseq/clib/libnat_cuda/edit_dist.cu", + "fairseq/clib/libnat_cuda/binding.cpp", + ], + ) + ] + ) + cmdclass["build_ext"] = cpp_extension.BuildExtension + +except ImportError: + pass + + +if "READTHEDOCS" in os.environ: + # don't build extensions when generating docs + extensions = [] + if "build_ext" in cmdclass: + del cmdclass["build_ext"] + + # use CPU build of PyTorch + dependency_links = [ + "https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl" + ] +else: + dependency_links = [] + + +if "clean" in sys.argv[1:]: + # Source: https://bit.ly/2NLVsgE + print("deleting Cython files...") + import subprocess + + subprocess.run( + ["rm -f fairseq/*.so fairseq/**/*.so fairseq/*.pyd fairseq/**/*.pyd"], + shell=True, + ) + + +def do_setup(package_data): + setup( + name="fairseq", + version=version, + description="Facebook AI Research Sequence-to-Sequence Toolkit", + url="https://github.com/pytorch/fairseq", + classifiers=[ + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.6", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + long_description=readme, + long_description_content_type="text/markdown", + setup_requires=[ + "cython", + "numpy", + "setuptools>=18.0", + ], + install_requires=[ + "cffi", + "cython", + "dataclasses", + "editdistance", + "hydra-core", + "numpy", + "regex", + "sacrebleu>=1.4.12", + "torch", + "tqdm", + ], + dependency_links=dependency_links, + packages=find_packages( + exclude=[ + "examples", + "examples.*", + "scripts", + "scripts.*", + "tests", + "tests.*", + ] + ), + package_data=package_data, + ext_modules=extensions, + test_suite="tests", + entry_points={ + "console_scripts": [ + "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", + "fairseq-generate = fairseq_cli.generate:cli_main", + "fairseq-interactive = fairseq_cli.interactive:cli_main", + "fairseq-preprocess = fairseq_cli.preprocess:cli_main", + "fairseq-score = fairseq_cli.score:cli_main", + "fairseq-train = fairseq_cli.train:cli_main", + "fairseq-validate = fairseq_cli.validate:cli_main", + ], + }, + cmdclass=cmdclass, + zip_safe=False, + ) + + +def get_files(path, relative_to="fairseq"): + all_files = [] + for root, _dirs, files in os.walk(path, followlinks=True): + root = os.path.relpath(root, relative_to) + for file in files: + if file.endswith(".pyc"): + continue + all_files.append(os.path.join(root, file)) + return all_files + + +try: + # symlink config and examples into fairseq package so package_data accepts them + if "build_ext" not in sys.argv[1:]: + os.symlink(os.path.join("..", "config"), "fairseq/config") + os.symlink(os.path.join("..", "examples"), "fairseq/examples") + package_data = { + "fairseq": get_files("fairseq/config") + get_files("fairseq/examples"), + } + do_setup(package_data) +finally: + if "build_ext" not in sys.argv[1:]: + os.unlink("fairseq/config") + os.unlink("fairseq/examples") diff --git a/train.py b/train.py new file mode 100644 index 0000000..321de3d --- /dev/null +++ b/train.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Legacy entry point. Use fairseq_cli/train.py or fairseq-train instead. +""" + +from fairseq_cli.train import cli_main + + +if __name__ == "__main__": + cli_main()