inhibit copying base model dependencies for summ, and translation (#3285)
* inhibit copying base model dependencies for summ, and translation * comments fix --------- Co-authored-by: Anubha Jain <anubhajain@microsoft.com>
This commit is contained in:
Родитель
73db714cd7
Коммит
ae510a0de7
|
@ -20,7 +20,7 @@ from azureml.acft.common_components.utils.mlflow_utils import update_acft_metada
|
|||
|
||||
from azureml.acft.contrib.hf.nlp.utils.common_utils import deep_update
|
||||
from azureml.acft.contrib.hf.nlp.constants.constants import (
|
||||
MLFlowHFFlavourConstants, MLFlowHFFlavourTasks, SaveFileConstants, HfModelTypes
|
||||
MLFlowHFFlavourConstants, MLFlowHFFlavourTasks, SaveFileConstants, HfModelTypes, Tasks
|
||||
)
|
||||
|
||||
import mlflow
|
||||
|
@ -269,10 +269,10 @@ class Pytorch_to_OSS_MlFlow_ModelConverter(ModelConverter, PyTorch_to_MlFlow_Mod
|
|||
yaml.safe_dump(conda_dict, f)
|
||||
logger.info("Updated conda.yaml file")
|
||||
|
||||
def is_t5_text_classification_finetune(self, model_type) -> bool:
|
||||
"""Check for t5 text-classification."""
|
||||
return self.mlflow_task_type == MLFlowHFFlavourTasks.SINGLE_LABEL_CLASSIFICATION and \
|
||||
model_type == HfModelTypes.T5
|
||||
def is_t5_finetune(self, model_type) -> bool:
|
||||
"""Check for t5 text-classification, translation, summarization."""
|
||||
return self.component_args.task_name in [Tasks.SINGLE_LABEL_CLASSIFICATION, Tasks.TRANSLATION,
|
||||
Tasks.SUMMARIZATION] and model_type == HfModelTypes.T5
|
||||
|
||||
def convert_model(self) -> None:
|
||||
"""Convert pytorch model to oss mlflow model."""
|
||||
|
@ -282,9 +282,9 @@ class Pytorch_to_OSS_MlFlow_ModelConverter(ModelConverter, PyTorch_to_MlFlow_Mod
|
|||
self.set_mlflow_model_parameters(model)
|
||||
|
||||
# Temp Fix:
|
||||
# specific check for t5 text-classification so that base model dependencies doesn't get pass
|
||||
# and use transformers version 4.40.0 from infer dependencies
|
||||
if not self.is_t5_text_classification_finetune(model.config.model_type):
|
||||
# specific check for t5 text-classification, translation, summarization so that base model
|
||||
# dependencies doesn't get pass and use transformers version 4.44.0 from infer dependencies
|
||||
if not self.is_t5_finetune(model.config.model_type):
|
||||
conda_file_path = Path(self.ft_pytorch_model_path, MLFlowHFFlavourConstants.CONDA_YAML_FILE)
|
||||
if conda_file_path.is_file():
|
||||
self.mlflow_save_model_kwargs.update({"conda_env": str(conda_file_path)})
|
||||
|
@ -316,8 +316,8 @@ class Pytorch_to_OSS_MlFlow_ModelConverter(ModelConverter, PyTorch_to_MlFlow_Mod
|
|||
self.add_model_signature()
|
||||
self.copy_finetune_config(self.ft_pytorch_model_path, self.mlflow_model_save_path)
|
||||
|
||||
# Temp fix for t5 text-classification
|
||||
if self.is_t5_text_classification_finetune(model.config.model_type):
|
||||
# Temp fix for t5 text-classification, translation, summarization
|
||||
if self.is_t5_finetune(model.config.model_type):
|
||||
self.remove_unwanted_packages(self.mlflow_model_save_path)
|
||||
|
||||
logger.info("Saved MLFlow model using OSS flavour.")
|
||||
|
|
Загрузка…
Ссылка в новой задаче