Maggie Mhanna 2022-10-05 12:37:31 +02:00
Parents: a20a5d3e15 342b879771
Commit: 84e5c1fe36
5 changed files with 139 additions and 15 deletions

View file

@@ -17,4 +17,4 @@ pytest==7.1.2
pytest-cov==2.12.1
# Fix: force protobuf downgrade to avoid exception
protobuf==3.20.1
protobuf==3.20.2
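Note on the hunk above: the bump from 3.20.1 to 3.20.2 stays on the 3.20 line that the comment pins (most likely to avoid the well-known protobuf 4.x "Descriptors cannot be created directly" exception), presumably just picking up the latest 3.20.x patch release. A quick sanity check that the pin is what actually got installed (an assumed helper, not part of the repo):

import google.protobuf

# The requirements file pins protobuf to the 3.20.x line; fail loudly if pip resolved something else.
assert google.protobuf.__version__.startswith("3.20."), google.protobuf.__version__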

View file

@@ -101,22 +101,22 @@ command: >-
  python train.py
  --train_images ${{inputs.train_images}}
  --valid_images ${{inputs.valid_images}}
  [--batch_size ${{inputs.batch_size}}]
  [--num_workers ${{inputs.num_workers}}]
  [--prefetch_factor ${{inputs.prefetch_factor}}]
  [--persistent_workers ${{inputs.persistent_workers}}]
  [--pin_memory ${{inputs.pin_memory}}]
  [--non_blocking ${{inputs.non_blocking}}]
  [--model_arch ${{inputs.model_arch}}]
  [--model_arch_pretrained ${{inputs.model_arch_pretrained}}]
  [--num_epochs ${{inputs.num_epochs}}]
  [--learning_rate ${{inputs.learning_rate}}]
  [--momentum ${{inputs.momentum}}]
  $[[--batch_size ${{inputs.batch_size}}]]
  $[[--num_workers ${{inputs.num_workers}}]]
  $[[--prefetch_factor ${{inputs.prefetch_factor}}]]
  $[[--persistent_workers ${{inputs.persistent_workers}}]]
  $[[--pin_memory ${{inputs.pin_memory}}]]
  $[[--non_blocking ${{inputs.non_blocking}}]]
  $[[--model_arch ${{inputs.model_arch}}]]
  $[[--model_arch_pretrained ${{inputs.model_arch_pretrained}}]]
  $[[--num_epochs ${{inputs.num_epochs}}]]
  $[[--learning_rate ${{inputs.learning_rate}}]]
  $[[--momentum ${{inputs.momentum}}]]
  --model_output ${{outputs.trained_model}}
  --checkpoints ${{outputs.checkpoints}}
  [--register_model_as ${{inputs.register_model_as}}]
  $[[--register_model_as ${{inputs.register_model_as}}]]
  --enable_profiling ${{inputs.enable_profiling}}
  [--multiprocessing_sharing_strategy ${{inputs.multiprocessing_sharing_strategy}}]
  $[[--multiprocessing_sharing_strategy ${{inputs.multiprocessing_sharing_strategy}}]]
distribution:
  # NOTE: using type:pytorch will use all the right env variables for pytorch init_process_group
  type: pytorch
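Note on the hunk above: $[[ ... ]] is how the AzureML CLI v2 command component schema marks arguments tied to optional inputs. When the wrapped input is not supplied at job submission, the whole bracketed fragment, flag included, is dropped from the rendered command line, so train.py can fall back to its own argparse defaults. A minimal sketch of the pattern (component name, environment reference and code path are placeholders, not taken from this repo):

$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json
type: command
name: train_example                        # hypothetical component name
environment: azureml:my-train-env@latest   # hypothetical environment reference
code: ./src                                # assumed location of train.py
inputs:
  train_images:
    type: uri_folder
  batch_size:
    type: integer
    optional: true
# When batch_size is omitted, the $[[ ... ]] fragment below disappears from the command.
command: >-
  python train.py
  --train_images ${{inputs.train_images}}
  $[[--batch_size ${{inputs.batch_size}}]]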

View file

@@ -0,0 +1,59 @@
data "azurerm_client_config" "current" {}

resource "azurerm_kusto_cluster" "cluster" {
  name                        = "adx${var.prefix}${var.postfix}${var.env}"
  location                    = var.location
  resource_group_name         = var.rg_name
  streaming_ingestion_enabled = true
  language_extensions         = ["PYTHON"]
  count                       = var.enable_monitoring ? 1 : 0

  sku {
    name     = "Standard_D11_v2"
    capacity = 2
  }
  tags = var.tags
}

resource "azurerm_kusto_database" "database" {
  name                = "mlmonitoring"
  resource_group_name = var.rg_name
  location            = var.location
  cluster_name        = azurerm_kusto_cluster.cluster[0].name
  count               = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "SP_ID" {
  name         = "kvmonitoringspid"
  value        = data.azurerm_client_config.current.client_id
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "SP_KEY" {
  name         = "kvmonitoringspkey"
  value        = var.client_secret
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "SP_TENANT_ID" {
  name         = "kvmonitoringadxtenantid"
  value        = data.azurerm_client_config.current.tenant_id
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "ADX_URI" {
  name         = "kvmonitoringadxuri"
  value        = azurerm_kusto_cluster.cluster[0].uri
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}

resource "azurerm_key_vault_secret" "ADX_DB" {
  name         = "kvmonitoringadxdb"
  value        = azurerm_kusto_database.database[0].name
  key_vault_id = var.key_vault_id
  count        = var.enable_monitoring ? 1 : 0
}
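The Key Vault secrets created above (service principal id and key, tenant id, ADX cluster URI and database name) are what downstream monitoring code would read to push or query metrics in the Kusto database. A minimal sketch of such a consumer, assuming the azure-identity, azure-keyvault-secrets and azure-kusto-data packages; the vault URL and the query are placeholders:

from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from azure.kusto.data import KustoClient, KustoConnectionStringBuilder

# Assumed vault URL; in this setup it would be the Key Vault behind var.key_vault_id.
kv = SecretClient(vault_url="https://<your-keyvault>.vault.azure.net",
                  credential=DefaultAzureCredential())

# Secret names match the azurerm_key_vault_secret resources above.
sp_id     = kv.get_secret("kvmonitoringspid").value
sp_key    = kv.get_secret("kvmonitoringspkey").value
tenant_id = kv.get_secret("kvmonitoringadxtenantid").value
adx_uri   = kv.get_secret("kvmonitoringadxuri").value
adx_db    = kv.get_secret("kvmonitoringadxdb").value

# Authenticate against the ADX cluster with the stored service principal.
kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication(
    adx_uri, sp_id, sp_key, tenant_id)
client = KustoClient(kcsb)

# Placeholder query against the monitoring database.
for row in client.execute(adx_db, ".show tables").primary_results[0]:
    print(row)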

View file

@@ -0,0 +1,45 @@
variable "rg_name" {
  type        = string
  description = "Resource group name"
}

variable "location" {
  type        = string
  description = "Location of the resource group"
}

variable "tags" {
  type        = map(string)
  default     = {}
  description = "A mapping of tags which should be assigned to the deployed resource"
}

variable "prefix" {
  type        = string
  description = "Prefix for the module name"
}

variable "postfix" {
  type        = string
  description = "Postfix for the module name"
}

variable "env" {
  type        = string
  description = "Environment prefix"
}

variable "key_vault_id" {
  type        = string
  description = "The ID of the Key Vault linked to the AML workspace"
}

variable "enable_monitoring" {
  description = "Variable to enable or disable the monitoring (Azure Data Explorer) resources"
  default     = false
}

variable "client_secret" {
  description = "Service principal client secret used for Azure Data Explorer access"
  default     = false
}
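For context, the module above would typically be wired in from the root Terraform configuration; a hypothetical call (module path, referenced resources and variable names outside this diff are assumptions):

module "data_explorer" {
  source            = "./modules/data-explorer"   # assumed module path
  rg_name           = azurerm_resource_group.default.name
  location          = azurerm_resource_group.default.location
  prefix            = var.prefix
  postfix           = var.postfix
  env               = var.environment
  key_vault_id      = module.aml_workspace.key_vault_id
  enable_monitoring = var.enable_monitoring
  client_secret     = var.client_secret
  tags              = local.tags
}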

View file

@@ -10,6 +10,7 @@ from transformers import (
    HfArgumentParser,
    IntervalStrategy,
)
from transformers.trainer_callback import TrainerCallback

import torch
import nltk
@@ -108,6 +109,21 @@ def compute_metrics(eval_preds, tokenizer, metric):
    return result


class CustomCallback(TrainerCallback):
    """A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/) via MLflow.
    This is a hotfix for the issue raised here:
    https://github.com/huggingface/transformers/issues/18870
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            metrics = {}
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics[k] = v
            mlflow.log_metrics(metrics=metrics, step=state.global_step)


def main():
    # Setup logging
    logger = logging.getLogger()
@@ -119,6 +135,9 @@ def main():
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # initialize the mlflow session
    mlflow.start_run()

    parser = HfArgumentParser((ModelArgs, DataArgs, Seq2SeqTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    logger.info(f"Running with arguments: {model_args}, {data_args}, {training_args}")
@@ -188,7 +207,7 @@ def main():
    training_args.save_strategy = "epoch"
    training_args.evaluation_strategy = IntervalStrategy.EPOCH
    training_args.predict_with_generate = True
    training_args.report_to = ["mlflow"]
    training_args.report_to = []  # use our own callback
    logger.info(f"training args: {training_args}")

    # Initialize our Trainer
@@ -200,6 +219,7 @@ def main():
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=lambda preds: compute_metrics(preds, tokenizer, metric),
        callbacks=[CustomCallback]
    )

    # Start the actual training (to include evaluation use --do-eval)
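The new CustomCallback only forwards numeric values from the Trainer logs, and main() now opens the MLflow run explicitly with mlflow.start_run() instead of relying on report_to=["mlflow"]; inside an AzureML job the MLflow tracking URI already points at the workspace, so the metrics should show up in the job's metrics view. A local smoke test for the callback, assumed and not part of this commit:

import mlflow
from transformers import TrainingArguments
from transformers.trainer_callback import TrainerControl, TrainerState

# Assumes CustomCallback (defined in the training script above) is in scope.
cb = CustomCallback()
args = TrainingArguments(output_dir="outputs")  # minimal arguments object
state = TrainerState()                          # is_world_process_zero defaults to True
state.global_step = 10

with mlflow.start_run():
    # Only int/float values are forwarded; the string entry is skipped.
    cb.on_log(args, state, TrainerControl(),
              logs={"loss": 1.23, "epoch": 1.0, "note": "ignored"})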