Mirror of https://github.com/microsoft/archai.git
chore(scripts): Adds fine-tune option to DeepSpeed training script.
Parent: fe957c5236
Commit: bf22f26fc1
@@ -55,6 +55,12 @@ deepspeed deepspeed/train_codegen.py --help
You can customize the training by modifying the arguments defined in `CodeGenFlashConfig`, `DsTrainingArguments`, and `ds_config.json`. By default, the arguments are set to run a toy training that illustrates how the pipeline works.
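For reference, the sketch below mirrors the only two values the training script itself passes to `DsTrainingArguments` (a run name and the DeepSpeed config path); the config file name is a placeholder, and it assumes the remaining keyword arguments, which are elided in the diff below, have defaults:

```python
from archai.trainers.nlp.ds_training_args import DsTrainingArguments

# Minimal sketch of pointing the training at a customized DeepSpeed config.
# "path/to/my_ds_config.json" is only an illustrative placeholder.
training_args = DsTrainingArguments(
    "ds-codegen",  # run name used by the script
    ds_config="path/to/my_ds_config.json",
)
```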
Additionally, if you have a model that was previously trained with DeepSpeed, you can continue its training or fine-tune it as follows:
```bash
deepspeed deepspeed/train_codegen.py --pre_trained_model_path <path_to_checkpoint>
```
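The `<path_to_checkpoint>` directory should contain the model-state files DeepSpeed writes when it saves a run; the loading code added in this commit looks for either a consolidated `mp_rank_00_model_states.pt` or per-layer `layer_XX-model_states.pt` files. A small sketch for sanity-checking a checkpoint directory before launching (the helper is illustrative, not part of the script):

```python
import glob
import os


def describe_checkpoint(path: str) -> str:
    """Report which kind of DeepSpeed model-state files a checkpoint folder holds."""
    # Consolidated checkpoint: the whole module state lives in a single file.
    if os.path.isfile(os.path.join(path, "mp_rank_00_model_states.pt")):
        return "consolidated module checkpoint"
    # Pipeline-style checkpoint: one state file per layer (layer_00, layer_01, ...).
    if glob.glob(os.path.join(path, "layer_*-model_states.pt")):
        return "per-layer checkpoint"
    return "no DeepSpeed model states found"
```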
## Hugging Face
If you are using Hugging Face, run the following command to begin training:
@@ -2,7 +2,9 @@
# Licensed under the MIT license.
import argparse
import os
import torch
from transformers import AutoTokenizer
from archai.datasets.nlp.fast_hf_dataset_provider import (
@@ -12,14 +14,13 @@ from archai.datasets.nlp.fast_hf_dataset_provider import (
from archai.discrete_search.search_spaces.nlp.tfpp.modeling_codegen_flash import (
    CodeGenFlashConfig,
    CodeGenFlashSequential,
    LMHeadLoss,
)
from archai.trainers.nlp.ds_trainer import DsTrainer
from archai.trainers.nlp.ds_training_args import DsTrainingArguments
def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Trains a CodeGen model with DeepSpeed.")
+    parser = argparse.ArgumentParser(description="Trains/Fine-tunes a CodeGen model with DeepSpeed.")
    parser.add_argument(
        "-dn",
@@ -37,6 +38,14 @@ def parse_args() -> argparse.Namespace:
help="Configuration name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
|
||||
    parser.add_argument(
        "-ptm",
        "--pre_trained_model_path",
        type=str,
        default=None,
        help="Path to the pre-trained model.",
    )
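    # Fine-tuning is opt-in: with the default of None, no checkpoint is loaded and training starts from scratch.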
    parser.add_argument(
        "-ds",
        "--ds_config_path",
@@ -103,6 +112,15 @@ if __name__ == "__main__":
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||
|
||||
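    # Resume/fine-tune path: a consolidated DeepSpeed checkpoint keeps the whole
    # module in mp_rank_00_model_states.pt, whereas checkpoints written with
    # pipeline parallelism keep one layer_XX-model_states.pt file per layer.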
    if args.pre_trained_model_path is not None:
        state_dict = torch.load(os.path.join(args.pre_trained_model_path, "mp_rank_00_model_states.pt"))
        if state_dict["module"] is not None:
            model.load_state_dict(state_dict["module"])
        else:
            for i, layer in enumerate(model.layers):
                state_dict = torch.load(os.path.join(args.pre_trained_model_path, f"layer_{i:02d}-model_states.pt"))
                layer.load_state_dict(state_dict)
    training_args = DsTrainingArguments(
        "ds-codegen",
        ds_config=args.ds_config_path,