diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index cc8ea89f0..28a33091f 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -406,12 +406,15 @@ def main(): # Training if training_args.do_train: + checkpoint = None if last_checkpoint is not None: checkpoint = last_checkpoint elif os.path.isdir(model_args.model_name_or_path): - checkpoint = model_args.model_name_or_path - else: - checkpoint = None + # Check the config from that potential checkpoint has the right number of labels before using it as a + # checkpoint. + if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels: + checkpoint = model_args.model_name_or_path + train_result = trainer.train(resume_from_checkpoint=checkpoint) metrics = train_result.metrics