inhibit copying base model dependencies for summ, and translation (#3285)

* inhibit copying base model dependencies for summ, and translation

* comments fix

---------

Co-authored-by: Anubha Jain <anubhajain@microsoft.com>
This commit is contained in:
Anubha Jain 2024-08-22 14:00:51 +05:30 коммит произвёл GitHub
Родитель 73db714cd7
Коммит ae510a0de7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 10 добавлений и 10 удалений

Просмотреть файл

@ -20,7 +20,7 @@ from azureml.acft.common_components.utils.mlflow_utils import update_acft_metada
from azureml.acft.contrib.hf.nlp.utils.common_utils import deep_update
from azureml.acft.contrib.hf.nlp.constants.constants import (
MLFlowHFFlavourConstants, MLFlowHFFlavourTasks, SaveFileConstants, HfModelTypes
MLFlowHFFlavourConstants, MLFlowHFFlavourTasks, SaveFileConstants, HfModelTypes, Tasks
)
import mlflow
@ -269,10 +269,10 @@ class Pytorch_to_OSS_MlFlow_ModelConverter(ModelConverter, PyTorch_to_MlFlow_Mod
yaml.safe_dump(conda_dict, f)
logger.info("Updated conda.yaml file")
def is_t5_text_classification_finetune(self, model_type) -> bool:
"""Check for t5 text-classification."""
return self.mlflow_task_type == MLFlowHFFlavourTasks.SINGLE_LABEL_CLASSIFICATION and \
model_type == HfModelTypes.T5
def is_t5_finetune(self, model_type) -> bool:
"""Check for t5 text-classification, translation, summarization."""
return self.component_args.task_name in [Tasks.SINGLE_LABEL_CLASSIFICATION, Tasks.TRANSLATION,
Tasks.SUMMARIZATION] and model_type == HfModelTypes.T5
def convert_model(self) -> None:
"""Convert pytorch model to oss mlflow model."""
@ -282,9 +282,9 @@ class Pytorch_to_OSS_MlFlow_ModelConverter(ModelConverter, PyTorch_to_MlFlow_Mod
self.set_mlflow_model_parameters(model)
# Temp Fix:
# specific check for t5 text-classification so that base model dependencies doesn't get pass
# and use transformers version 4.40.0 from infer dependencies
if not self.is_t5_text_classification_finetune(model.config.model_type):
# specific check for t5 text-classification, translation, summarization so that base model
# dependencies doesn't get pass and use transformers version 4.44.0 from infer dependencies
if not self.is_t5_finetune(model.config.model_type):
conda_file_path = Path(self.ft_pytorch_model_path, MLFlowHFFlavourConstants.CONDA_YAML_FILE)
if conda_file_path.is_file():
self.mlflow_save_model_kwargs.update({"conda_env": str(conda_file_path)})
@ -316,8 +316,8 @@ class Pytorch_to_OSS_MlFlow_ModelConverter(ModelConverter, PyTorch_to_MlFlow_Mod
self.add_model_signature()
self.copy_finetune_config(self.ft_pytorch_model_path, self.mlflow_model_save_path)
# Temp fix for t5 text-classification
if self.is_t5_text_classification_finetune(model.config.model_type):
# Temp fix for t5 text-classification, translation, summarization
if self.is_t5_finetune(model.config.model_type):
self.remove_unwanted_packages(self.mlflow_model_save_path)
logger.info("Saved MLFlow model using OSS flavour.")