Merge branch 'main' of https://github.com/Azure/mlops-project-template into dev
Commit 84e5c1fe36
@@ -17,4 +17,4 @@ pytest==7.1.2
 pytest-cov==2.12.1
 # Fix: force protobuf downgrade to avoid exception
-protobuf==3.20.1
+protobuf==3.20.2
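The pin moves from 3.20.1 to 3.20.2, still holding protobuf below 4.x, where the descriptor-related TypeError this comment refers to is raised at import time. A minimal sketch of a guard (hypothetical, not part of this repo) that fails fast if the pin was not respected at install time:

    # Hypothetical guard, not part of this repo: fail fast if the protobuf pin
    # from requirements.txt is not what actually got installed.
    from importlib.metadata import version

    installed = version("protobuf")
    if installed != "3.20.2":
        raise RuntimeError(
            f"Expected the pinned protobuf==3.20.2, found {installed}; "
            "protobuf >= 4.x triggers the descriptor exception this pin avoids."
        )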
@@ -101,22 +101,22 @@ command: >-
   python train.py
   --train_images ${{inputs.train_images}}
   --valid_images ${{inputs.valid_images}}
-  [--batch_size ${{inputs.batch_size}}]
-  [--num_workers ${{inputs.num_workers}}]
-  [--prefetch_factor ${{inputs.prefetch_factor}}]
-  [--persistent_workers ${{inputs.persistent_workers}}]
-  [--pin_memory ${{inputs.pin_memory}}]
-  [--non_blocking ${{inputs.non_blocking}}]
-  [--model_arch ${{inputs.model_arch}}]
-  [--model_arch_pretrained ${{inputs.model_arch_pretrained}}]
-  [--num_epochs ${{inputs.num_epochs}}]
-  [--learning_rate ${{inputs.learning_rate}}]
-  [--momentum ${{inputs.momentum}}]
+  $[[--batch_size ${{inputs.batch_size}}]]
+  $[[--num_workers ${{inputs.num_workers}}]]
+  $[[--prefetch_factor ${{inputs.prefetch_factor}}]]
+  $[[--persistent_workers ${{inputs.persistent_workers}}]]
+  $[[--pin_memory ${{inputs.pin_memory}}]]
+  $[[--non_blocking ${{inputs.non_blocking}}]]
+  $[[--model_arch ${{inputs.model_arch}}]]
+  $[[--model_arch_pretrained ${{inputs.model_arch_pretrained}}]]
+  $[[--num_epochs ${{inputs.num_epochs}}]]
+  $[[--learning_rate ${{inputs.learning_rate}}]]
+  $[[--momentum ${{inputs.momentum}}]]
   --model_output ${{outputs.trained_model}}
   --checkpoints ${{outputs.checkpoints}}
-  [--register_model_as ${{inputs.register_model_as}}]
+  $[[--register_model_as ${{inputs.register_model_as}}]]
   --enable_profiling ${{inputs.enable_profiling}}
-  [--multiprocessing_sharing_strategy ${{inputs.multiprocessing_sharing_strategy}}]
+  $[[--multiprocessing_sharing_strategy ${{inputs.multiprocessing_sharing_strategy}}]]
 distribution:
   # NOTE: using type:pytorch will use all the right env variables for pytorch init_process_group
   type: pytorch
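The bracket change is the substantive edit in this hunk: in AzureML CLI v2 command components, `$[[ ... ]]` marks a segment that is dropped from the rendered command when the optional input it references is unbound, whereas the older plain `[ ... ]` text is not expanded that way. A sketch of consuming the updated component from the v2 Python SDK (the YAML path and dataset URIs are placeholders, not from this diff):

    # Sketch, assuming the azure-ai-ml (v2) SDK; file name and paths are placeholders.
    from azure.ai.ml import Input, load_component

    train = load_component(source="train.yml")
    step = train(
        train_images=Input(type="uri_folder", path="azureml://datastores/images/paths/train"),
        valid_images=Input(type="uri_folder", path="azureml://datastores/images/paths/valid"),
        batch_size=64,  # bound optional input -> "--batch_size 64" is rendered
        # num_workers left unbound -> the whole $[[--num_workers ...]] segment
        # disappears from the command line instead of leaking literal brackets.
    )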
@@ -0,0 +1,59 @@
data "azurerm_client_config" "current" {}

resource "azurerm_kusto_cluster" "cluster" {
  name                        = "adx${var.prefix}${var.postfix}${var.env}"
  location                    = var.location
  resource_group_name         = var.rg_name
  streaming_ingestion_enabled = true
  language_extensions         = ["PYTHON"]
  count                       = var.enable_monitoring ? 1 : 0

  sku {
    name     = "Standard_D11_v2"
    capacity = 2
  }

  tags = var.tags
}

resource "azurerm_kusto_database" "database" {
  name                = "mlmonitoring"
  resource_group_name = var.rg_name
  location            = var.location
  cluster_name        = azurerm_kusto_cluster.cluster[0].name
  count               = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "SP_ID" {
  name         = "kvmonitoringspid"
  value        = data.azurerm_client_config.current.client_id
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "SP_KEY" {
  name         = "kvmonitoringspkey"
  value        = var.client_secret
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "SP_TENANT_ID" {
  name         = "kvmonitoringadxtenantid"
  value        = data.azurerm_client_config.current.tenant_id
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "ADX_URI" {
  name         = "kvmonitoringadxuri"
  value        = azurerm_kusto_cluster.cluster[0].uri
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "ADX_DB" {
  name         = "kvmonitoringadxdb"
  value        = azurerm_kusto_database.database[0].name
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}
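Everything in this file is gated on var.enable_monitoring, and the five secrets exist so monitoring code can rebuild a service-principal connection to the ADX cluster without hard-coded credentials. A sketch of the consuming side (assumes the azure-identity, azure-keyvault-secrets, and azure-kusto-data packages; the vault URL is a placeholder):

    # Sketch: read the secrets written above and open a Kusto client with them.
    # The vault URL is a placeholder; the secret names match the Terraform file.
    from azure.identity import DefaultAzureCredential
    from azure.keyvault.secrets import SecretClient
    from azure.kusto.data import KustoClient, KustoConnectionStringBuilder

    vault = SecretClient("https://<your-keyvault>.vault.azure.net", DefaultAzureCredential())

    kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication(
        vault.get_secret("kvmonitoringadxuri").value,       # ADX_URI
        vault.get_secret("kvmonitoringspid").value,         # SP_ID
        vault.get_secret("kvmonitoringspkey").value,        # SP_KEY
        vault.get_secret("kvmonitoringadxtenantid").value,  # SP_TENANT_ID
    )
    client = KustoClient(kcsb)
    database = vault.get_secret("kvmonitoringadxdb").value  # ADX_DB ("mlmonitoring")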
@@ -0,0 +1,45 @@
variable "rg_name" {
  type        = string
  description = "Resource group name"
}

variable "location" {
  type        = string
  description = "Location of the resource group"
}

variable "tags" {
  type        = map(string)
  default     = {}
  description = "A mapping of tags which should be assigned to the deployed resource"
}

variable "prefix" {
  type        = string
  description = "Prefix for the module name"
}

variable "postfix" {
  type        = string
  description = "Postfix for the module name"
}

variable "env" {
  type        = string
  description = "Environment prefix"
}

variable "key_vault_id" {
  type        = string
  description = "The ID of the Key Vault linked to the AML workspace"
}

variable "enable_monitoring" {
  description = "Variable to enable or disable the monitoring infrastructure (ADX cluster and database)"
  default     = false
}

variable "client_secret" {
  description = "Service principal client secret"
  default     = false
}
@@ -10,6 +10,7 @@ from transformers import (
     HfArgumentParser,
     IntervalStrategy,
 )
+from transformers.trainer_callback import TrainerCallback

 import torch
 import nltk
@@ -108,6 +109,21 @@ def compute_metrics(eval_preds, tokenizer, metric):
     return result


+class CustomCallback(TrainerCallback):
+    """A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
+
+    This is a hotfix for the issue raised here:
+    https://github.com/huggingface/transformers/issues/18870
+    """
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if state.is_world_process_zero:
+            metrics = {}
+            for k, v in logs.items():
+                if isinstance(v, (int, float)):
+                    metrics[k] = v
+            mlflow.log_metrics(metrics=metrics, step=state.global_step)
+
+
 def main():
     # Setup logging
     logger = logging.getLogger()
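Two details of the new callback are worth noting: state.is_world_process_zero restricts logging to rank 0 under the pytorch distribution configured above, and the isinstance filter drops non-numeric entries because mlflow.log_metrics() accepts numbers only. The filtering step in isolation (the example logs dict is made up, but has the shape transformers passes to callbacks):

    logs = {"loss": 0.42, "learning_rate": 5e-05, "epoch": 1.0, "stage": "train"}
    metrics = {k: v for k, v in logs.items() if isinstance(v, (int, float))}
    print(metrics)  # {'loss': 0.42, 'learning_rate': 5e-05, 'epoch': 1.0}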
@@ -119,6 +135,9 @@ def main():
     console_handler.setFormatter(formatter)
     logger.addHandler(console_handler)

+    # initialize the mlflow session
+    mlflow.start_run()
+
     parser = HfArgumentParser((ModelArgs, DataArgs, Seq2SeqTrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
     logger.info(f"Running with arguments: {model_args}, {data_args}, {training_args}")
@@ -188,7 +207,7 @@ def main():
     training_args.save_strategy = "epoch"
     training_args.evaluation_strategy = IntervalStrategy.EPOCH
     training_args.predict_with_generate = True
-    training_args.report_to = ["mlflow"]
+    training_args.report_to = []  # use our own callback
     logger.info(f"training args: {training_args}")

     # Initialize our Trainer
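Emptying report_to keeps transformers from attaching its built-in MLflow integration, the one implicated in the issue 18870 hotfix above, so metric reporting flows only through CustomCallback. The same setting in isolation (output_dir is a placeholder):

    from transformers import TrainingArguments

    # Sketch: report_to=[] disables every built-in reporter (mlflow,
    # tensorboard, wandb, ...); "outputs" is a placeholder directory.
    args = TrainingArguments(output_dir="outputs", report_to=[])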
@@ -200,6 +219,7 @@ def main():
         tokenizer=tokenizer,
         data_collator=data_collator,
         compute_metrics=lambda preds : compute_metrics(preds, tokenizer, metric),
+        callbacks=[CustomCallback]
     )

     # Start the actual training (to include evaluation use --do-eval)