diff --git a/scripts/trainers/README.md b/scripts/trainers/README.md
index 4f91cd3d..4d40ad1b 100644
--- a/scripts/trainers/README.md
+++ b/scripts/trainers/README.md
@@ -55,6 +55,12 @@ deepspeed deepspeed/train_codegen.py --help
 
 You can customize the training by modifying the arguments defined in `CodeGenFlashConfig`, `DsTrainingArguments`, and `ds_config.json`. By default, the arguments are set to perform a toy training run and explain how the pipeline works.
 
+Additionally, if you have a model that was previously trained with DeepSpeed, you can continue its training or fine-tune it as follows:
+
+```bash
+deepspeed deepspeed/train_codegen.py --pre_trained_model_path <path_to_checkpoint>
+```
+
 ## Hugging Face
 
 If you are using Hugging Face, run the following command to begin training:
diff --git a/scripts/trainers/deepspeed/train_codegen.py b/scripts/trainers/deepspeed/train_codegen.py
index 39533ec8..97f77db2 100644
--- a/scripts/trainers/deepspeed/train_codegen.py
+++ b/scripts/trainers/deepspeed/train_codegen.py
@@ -2,7 +2,9 @@
 # Licensed under the MIT license.
 
 import argparse
+import os
 
+import torch
 from transformers import AutoTokenizer
 
 from archai.datasets.nlp.fast_hf_dataset_provider import (
@@ -12,14 +14,13 @@ from archai.datasets.nlp.fast_hf_dataset_provider import (
 from archai.discrete_search.search_spaces.nlp.tfpp.modeling_codegen_flash import (
     CodeGenFlashConfig,
     CodeGenFlashSequential,
-    LMHeadLoss,
 )
 from archai.trainers.nlp.ds_trainer import DsTrainer
 from archai.trainers.nlp.ds_training_args import DsTrainingArguments
 
 
 def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Trains a CodeGen model with DeepSpeed.")
+    parser = argparse.ArgumentParser(description="Trains or fine-tunes a CodeGen model with DeepSpeed.")
 
     parser.add_argument(
         "-dn",
@@ -37,6 +38,14 @@ def parse_args() -> argparse.Namespace:
         help="Configuration name of the dataset to use (via the datasets library).",
     )
 
+    parser.add_argument(
+        "-ptm",
+        "--pre_trained_model_path",
+        type=str,
+        default=None,
+        help="Path to the pre-trained model.",
+    )
+
     parser.add_argument(
         "-ds",
         "--ds_config_path",
@@ -103,6 +112,15 @@ if __name__ == "__main__":
 
     print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
 
+    if args.pre_trained_model_path is not None:
+        state_dict = torch.load(os.path.join(args.pre_trained_model_path, "mp_rank_00_model_states.pt"))
+        if state_dict["module"] is not None:
+            model.load_state_dict(state_dict["module"])
+        else:
+            for i, layer in enumerate(model.layers):
+                state_dict = torch.load(os.path.join(args.pre_trained_model_path, f"layer_{i:02d}-model_states.pt"))
+                layer.load_state_dict(state_dict)
+
     training_args = DsTrainingArguments(
         "ds-codegen",
         ds_config=args.ds_config_path,
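
Note: the checkpoint-loading branch added in the last hunk handles the two layouts DeepSpeed checkpoints can take: a single `mp_rank_00_model_states.pt` whose `"module"` entry holds the full state dict, or (when the engine ran with pipeline parallelism) per-stage `layer_{i:02d}-model_states.pt` files with `"module"` set to `None`. Below is a minimal standalone sketch of that same logic, factored into a helper for clarity; the helper name, the `map_location="cpu"` choice, and the assumption that the model exposes its pipeline stages as `model.layers` (as `CodeGenFlashSequential` does in the patch) are illustrative, not part of the PR.

```python
# Hedged sketch: mirrors the loading logic from the patch, assuming DeepSpeed's
# default checkpoint filenames and a model with a `layers` attribute.
import os

import torch


def load_deepspeed_checkpoint(model: torch.nn.Module, checkpoint_dir: str) -> None:
    """Load weights saved by DeepSpeed into `model` before resuming training."""

    state_dict = torch.load(
        os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt"),
        map_location="cpu",  # avoid requiring the device layout used at save time
    )

    if state_dict.get("module") is not None:
        # Non-pipeline checkpoint: the whole model lives under "module".
        model.load_state_dict(state_dict["module"])
    else:
        # Pipeline-parallel checkpoint: one state file per layer/stage.
        for i, layer in enumerate(model.layers):
            layer_state = torch.load(
                os.path.join(checkpoint_dir, f"layer_{i:02d}-model_states.pt"),
                map_location="cpu",
            )
            layer.load_state_dict(layer_state)
```

Loading on CPU first keeps the helper usable regardless of how many GPUs were present when the checkpoint was written; DeepSpeed moves the parameters to the right devices once training resumes.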