Integrate MLflow; allow multiple sweeps (#90)

* MLflow integration working; threading issue solved; API route conflict when the widget is instantiated multiple times fixed

* Add parameters use_ml_flow and ml_flow_run_name to control MLflow logging

* Update README with MLflow info

* Set use_ml_flow=False by default. Add MLflow to requirements.txt.

Co-authored-by: Nicholas King <v-nicki@microsoft.com>
Nicholas King 2020-12-17 10:22:47 -08:00, committed by GitHub
Parent 96a92d4431
Commit 9765a40111
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
9 changed files: 111 additions and 40 deletions

.gitignore vendored
View file

@@ -120,6 +120,8 @@ examples/sweeps-adult/*state
examples/sweeps-cifar10/*.json
examples/sweeps-cifar10/*state
mlruns/
tests/sweeps/*.json
tests/sweeps/*state

View file

@@ -34,6 +34,11 @@ to see how the `backwardcompatibilityml` module is used.
To demo the widget, open the notebook `compatibility-analysis.ipynb`.
# MLflow
Compatibility sweeps can be logged with [MLflow](https://mlflow.org/) by passing `use_ml_flow=True`. MLflow runs are logged in a folder named `mlruns` in the same directory as the notebook.
To view the MLflow dashboard, start the MLflow server by running `mlflow server --port 5200 --backend-store-uri ./mlruns`. Then, open the MLflow UI
in your browser by navigating to `localhost:5200`.
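Logged runs can also be read back programmatically instead of through the UI; the following is a minimal sketch using the MLflow Python API, assuming the default `./mlruns` store and the parameter and metric names logged by the sweep:

```python
import mlflow

# Point MLflow at the local file store created next to the notebook.
mlflow.set_tracking_uri("./mlruns")

# One row per run; logged params and metrics show up as
# "params.<name>" and "metrics.<name>" columns.
runs = mlflow.search_runs()
print(runs[["run_id", "params.lambda_c_stepsize", "metrics.new_error_testing_btc"]])
```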
# Tests
To run tests, make sure that you are in the project root folder and do:

View file

@@ -3,8 +3,9 @@
import copy
import json
import torch
import mlflow
import numpy as np
import torch
import backwardcompatibilityml.scores as scores
from backwardcompatibilityml.metrics import model_accuracy
@@ -719,7 +720,9 @@ def compatibility_sweep(sweeps_folder_path, number_of_epochs, h1, h2,
new_error_loss_kwargs=None,
strict_imitation_loss_kwargs=None,
get_instance_metadata=None,
device="cpu"):
device="cpu",
use_ml_flow=False,
ml_flow_run_name="compatibility_sweep"):
"""
This function trains a new model using the backward compatibility loss function
BCNLLLoss with respect to an existing model. It does this for each value of
@@ -766,7 +769,17 @@ def compatibility_sweep(sweeps_folder_path, number_of_epochs, h1, h2,
value is "cpu". But in case your models reside on the GPU, make sure
to set this to "cuda". This makes sure that the input and target
tensors are transferred to the GPU during training.
use_ml_flow: A boolean flag controlling whether or not to log the sweep
with MLflow. If True, an MLflow run will be created with the name
specified by ml_flow_run_name.
ml_flow_run_name: A string that configures the name of the MLFlow run.
"""
if use_ml_flow:
mlflow.start_run(run_name=ml_flow_run_name)
mlflow.log_param('lambda_c_stepsize', lambda_c_stepsize)
mlflow.log_param('batch_size_train', batch_size_train)
mlflow.log_param('batch_size_test', batch_size_test)
h1.eval()
number_of_trainings = 4 * len(np.arange(0.0, 1.0 + (lambda_c_stepsize / 2), lambda_c_stepsize))
if percent_complete_queue is not None:
@@ -780,7 +793,9 @@ def compatibility_sweep(sweeps_folder_path, number_of_epochs, h1, h2,
sweep_summary_data = []
datapoint_index = 0
run_step = 0
for lambda_c in np.arange(0.0, 1.0 + (lambda_c_stepsize / 2), lambda_c_stepsize):
run_step += 1
h2_new_error = copy.deepcopy(h2)
train_new_error(
h1, h2_new_error, number_of_epochs,
@@ -811,6 +826,11 @@ def compatibility_sweep(sweeps_folder_path, number_of_epochs, h1, h2,
"btc": training_set_performance_and_compatibility["btc"],
"bec": training_set_performance_and_compatibility["bec"]
})
if use_ml_flow:
mlflow.log_metric(f"lambda_c", lambda_c, step=run_step)
mlflow.log_metric(f"new_error_training_performance", training_set_performance_and_compatibility["h2_performance"], step=run_step)
mlflow.log_metric(f"new_error_training_btc", training_set_performance_and_compatibility["btc"], step=run_step)
mlflow.log_metric(f"new_error_training_bec", training_set_performance_and_compatibility["bec"], step=run_step)
training_evaluation_data = json.dumps(training_set_performance_and_compatibility)
training_evaluation_data_file = open(f"{sweeps_folder_path}/{datapoint_index}-evaluation-data.json", "w")
training_evaluation_data_file.write(training_evaluation_data)
@@ -837,6 +857,10 @@ def compatibility_sweep(sweeps_folder_path, number_of_epochs, h1, h2,
"btc": testing_set_performance_and_compatibility["btc"],
"bec": testing_set_performance_and_compatibility["bec"]
})
if use_ml_flow:
mlflow.log_metric(f"new_error_testing_performance", testing_set_performance_and_compatibility["h2_performance"], step=run_step)
mlflow.log_metric(f"new_error_testing_btc", testing_set_performance_and_compatibility["btc"], step=run_step)
mlflow.log_metric(f"new_error_testing_bec", testing_set_performance_and_compatibility["bec"], step=run_step)
testing_evaluation_data = json.dumps(testing_set_performance_and_compatibility)
testing_evaluation_data_file = open(f"{sweeps_folder_path}/{datapoint_index}-evaluation-data.json", "w")
testing_evaluation_data_file.write(testing_evaluation_data)
@@ -873,6 +897,10 @@ def compatibility_sweep(sweeps_folder_path, number_of_epochs, h1, h2,
"btc": training_set_performance_and_compatibility["btc"],
"bec": training_set_performance_and_compatibility["bec"]
})
if use_ml_flow:
mlflow.log_metric(f"strict_imitation_training_performance", training_set_performance_and_compatibility["h2_performance"], step=run_step)
mlflow.log_metric(f"strict_imitation_training_btc", training_set_performance_and_compatibility["btc"], step=run_step)
mlflow.log_metric(f"strict_imitation_training_bec", training_set_performance_and_compatibility["bec"], step=run_step)
training_evaluation_data = json.dumps(training_set_performance_and_compatibility)
training_evaluation_data_file = open(f"{sweeps_folder_path}/{datapoint_index}-evaluation-data.json", "w")
training_evaluation_data_file.write(training_evaluation_data)
@@ -899,6 +927,10 @@ def compatibility_sweep(sweeps_folder_path, number_of_epochs, h1, h2,
"btc": testing_set_performance_and_compatibility["btc"],
"bec": testing_set_performance_and_compatibility["bec"]
})
if use_ml_flow:
mlflow.log_metric(f"strict_imitation_testing_performance", testing_set_performance_and_compatibility["h2_performance"], step=run_step)
mlflow.log_metric(f"strict_imitation_testing_btc", testing_set_performance_and_compatibility["btc"], step=run_step)
mlflow.log_metric(f"strict_imitation_testing_bec", testing_set_performance_and_compatibility["bec"], step=run_step)
testing_evaluation_data = json.dumps(testing_set_performance_and_compatibility)
testing_evaluation_data_file = open(f"{sweeps_folder_path}/{datapoint_index}-evaluation-data.json", "w")
testing_evaluation_data_file.write(testing_evaluation_data)
@@ -917,3 +949,5 @@ def compatibility_sweep(sweeps_folder_path, number_of_epochs, h1, h2,
sweep_summary_data_file = open(f"{sweeps_folder_path}/sweep_summary.json", "w")
sweep_summary_data_file.write(sweep_summary_data)
sweep_summary_data_file.close()
if use_ml_flow:
mlflow.end_run()
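Stripped of the training code, the MLflow usage added above follows one pattern: a single run wraps the whole sweep, parameters are logged once up front, and per-`lambda_c` metrics are logged with a `step` index so they appear as series. A minimal self-contained sketch of that pattern, with placeholder metric values:

```python
import mlflow
import numpy as np

lambda_c_stepsize = 0.25

# A single MLflow run spans the whole sweep.
mlflow.start_run(run_name="compatibility_sweep")
mlflow.log_param("lambda_c_stepsize", lambda_c_stepsize)

run_step = 0
for lambda_c in np.arange(0.0, 1.0 + (lambda_c_stepsize / 2), lambda_c_stepsize):
    run_step += 1
    # The real sweep logs performance, btc and bec here; these are placeholders.
    mlflow.log_metric("lambda_c", lambda_c, step=run_step)
    mlflow.log_metric("new_error_training_btc", 0.0, step=run_step)

mlflow.end_run()
```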

View file

@@ -6,6 +6,7 @@ import json
import threading
import io
import numpy as np
import mlflow
from flask import send_file
from PIL import Image
from queue import Queue
@@ -65,6 +66,10 @@ class SweepManager(object):
value is "cpu". But in case your models reside on the GPU, make sure
to set this to "cuda". This makes sure that the input and target
tensors are transferred to the GPU during training.
use_ml_flow: A boolean flag controlling whether or not to log the sweep
with MLflow. If True, an MLflow run will be created with the name
specified by ml_flow_run_name.
ml_flow_run_name: A string that configures the name of the MLFlow run.
"""
def __init__(self, folder_name, number_of_epochs, h1, h2, training_set, test_set,
@@ -76,7 +81,9 @@ class SweepManager(object):
performance_metric=model_accuracy,
get_instance_image_by_id=None,
get_instance_metadata=None,
device="cpu"):
device="cpu",
use_ml_flow=False,
ml_flow_run_name="compatibility_sweep"):
self.folder_name = folder_name
self.number_of_epochs = number_of_epochs
self.h1 = h1
@@ -96,28 +103,42 @@ class SweepManager(object):
self.get_instance_image_by_id = get_instance_image_by_id
self.get_instance_metadata = get_instance_metadata
self.device = device
self.use_ml_flow = use_ml_flow
self.ml_flow_run_name = ml_flow_run_name
self.last_sweep_status = 0.0
self.percent_complete_queue = Queue()
self.sweep_thread = None
def start_sweep(self):
if self.is_running():
return
self.percent_complete_queue = Queue()
self.last_sweep_status = 0.0
self.sweep_thread = threading.Thread(
target=training.compatibility_sweep,
args=(self.folder_name, self.number_of_epochs, self.h1, self.h2,
self.training_set, self.test_set,
self.batch_size_train, self.batch_size_test,
self.OptimizerClass, self.optimizer_kwargs,
self.NewErrorLossClass, self.StrictImitationLossClass,
self.performance_metric,),
self.training_set, self.test_set,
self.batch_size_train, self.batch_size_test,
self.OptimizerClass, self.optimizer_kwargs,
self.NewErrorLossClass, self.StrictImitationLossClass,
self.performance_metric,),
kwargs={
"lambda_c_stepsize": self.lambda_c_stepsize,
"percent_complete_queue": self.percent_complete_queue,
"new_error_loss_kwargs": self.new_error_loss_kwargs,
"strict_imitation_loss_kwargs": self.strict_imitation_loss_kwargs,
"get_instance_metadata": self.get_instance_metadata,
"device": self.device
"device": self.device,
"use_ml_flow": self.use_ml_flow,
"ml_flow_run_name": self.ml_flow_run_name
})
def start_sweep(self):
self.sweep_thread.start()
def is_running(self):
if not self.sweep_thread:
return False
return self.sweep_thread.is_alive()
def start_sweep_synchronous(self):
training.compatibility_sweep(
self.folder_name, self.number_of_epochs, self.h1, self.h2, self.training_set, self.test_set,

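The multiple-sweep support comes down to creating a fresh worker thread per sweep, refusing to start a new one while the previous thread is still alive, and reporting progress through a new Queue. A self-contained sketch of that pattern, with a stand-in function in place of `training.compatibility_sweep`:

```python
import threading
import time
from queue import Queue


def fake_sweep(percent_complete_queue=None):
    # Stand-in for training.compatibility_sweep: report progress as it runs.
    for step in range(1, 5):
        time.sleep(0.1)
        if percent_complete_queue is not None:
            percent_complete_queue.put(step / 4)


class MiniSweepManager:
    def __init__(self):
        self.percent_complete_queue = Queue()
        self.sweep_thread = None

    def is_running(self):
        if not self.sweep_thread:
            return False
        return self.sweep_thread.is_alive()

    def start_sweep(self):
        if self.is_running():
            return  # a sweep is already in progress; do not start another
        self.percent_complete_queue = Queue()
        self.sweep_thread = threading.Thread(
            target=fake_sweep,
            kwargs={"percent_complete_queue": self.percent_complete_queue})
        self.sweep_thread.start()


manager = MiniSweepManager()
manager.start_sweep()
manager.start_sweep()        # ignored: the first sweep is still running
manager.sweep_thread.join()  # wait for the sweep to finish
```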
View file

@@ -108,7 +108,7 @@ def init_app_routes(app, sweep_manager):
def start_sweep():
sweep_manager.start_sweep()
return {
"running": sweep_manager.sweep_thread.is_alive(),
"running": sweep_manager.is_running(),
"percent_complete": sweep_manager.get_sweep_status()
}
@@ -116,7 +116,7 @@ def init_app_routes(app, sweep_manager):
@http.no_cache
def get_sweep_status():
return {
"running": sweep_manager.sweep_thread.is_alive(),
"running": sweep_manager.is_running(),
"percent_complete": sweep_manager.get_sweep_status()
}
@@ -208,6 +208,10 @@ class CompatibilityAnalysis(object):
value is "cpu". But in case your models reside on the GPU, make sure
to set this to "cuda". This makes sure that the input and target
tensors are transferred to the GPU during training.
use_ml_flow: A boolean flag controlling whether or not to log the sweep
with MLflow. If True, an MLflow run will be created with the name
specified by ml_flow_run_name.
ml_flow_run_name: A string that configures the name of the MLFlow run.
"""
def __init__(self, folder_name, number_of_epochs, h1, h2, training_set, test_set,
@@ -219,7 +223,9 @@ class CompatibilityAnalysis(object):
strict_imitation_loss_kwargs=None,
get_instance_image_by_id=None,
get_instance_metadata=None,
device="cpu"):
device="cpu",
use_ml_flow=False,
ml_flow_run_name="compatibility_sweep"):
if OptimizerClass is None:
OptimizerClass = optim.SGD
@@ -253,10 +259,21 @@ class CompatibilityAnalysis(object):
performance_metric=performance_metric,
get_instance_image_by_id=get_instance_image_by_id,
get_instance_metadata=get_instance_metadata,
device=device)
device=device,
use_ml_flow=use_ml_flow,
ml_flow_run_name=ml_flow_run_name)
self.flask_service = FlaskHelper(ip="0.0.0.0", port=port)
init_app_routes(FlaskHelper.app, self.sweep_manager)
app_has_routes = False
for route in FlaskHelper.app.url_map.iter_rules():
if route.endpoint == 'start_sweep':
app_has_routes = True
break
if app_has_routes:
FlaskHelper.app.logger.info("Routes already defined. Skipping route initialization.")
else:
FlaskHelper.app.logger.info("Initializing routes")
init_app_routes(FlaskHelper.app, self.sweep_manager)
api_service_environment = build_environment_params(self.flask_service.env)
api_service_environment["port"] = self.flask_service.port
html_string = render_widget_html(api_service_environment)
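The route-conflict fix inspects the shared Flask app's `url_map` before registering routes, so a second widget instance reuses the existing endpoints instead of re-registering them. A standalone sketch of that guard, assuming only Flask and an illustrative `start_sweep` route:

```python
from flask import Flask

app = Flask(__name__)


def init_app_routes(app):
    # Illustrative stand-in for the widget's API routes.
    @app.route("/api/v1/start_sweep", endpoint="start_sweep", methods=["POST"])
    def start_sweep():
        return {"running": True}


def ensure_routes(app):
    # Register the routes only once, even if several widgets share the app.
    app_has_routes = any(
        rule.endpoint == "start_sweep" for rule in app.url_map.iter_rules())
    if app_has_routes:
        app.logger.info("Routes already defined. Skipping route initialization.")
    else:
        app.logger.info("Initializing routes")
        init_app_routes(app)


ensure_routes(app)
ensure_routes(app)  # no-op on the second call instead of an endpoint conflict
```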

View file

@@ -19,6 +19,8 @@ from flask import send_file
from PIL import Image
from rai_core_flask.flask_helper import FlaskHelper
use_ml_flow = True
ml_flow_run_name = "dev_app_sweep"
def breast_cancer_sweep():
folder_name = "tests/sweeps"
@@ -95,7 +97,9 @@ def breast_cancer_sweep():
OptimizerClass=optim.SGD,
optimizer_kwargs={"lr": learning_rate, "momentum": momentum},
NewErrorLossClass=bcloss.BCCrossEntropyLoss,
StrictImitationLossClass=bcloss.StrictImitationCrossEntropyLoss)
StrictImitationLossClass=bcloss.StrictImitationCrossEntropyLoss,
use_ml_flow=use_ml_flow,
ml_flow_run_name=ml_flow_run_name)
def mnist_sweep():
@@ -206,9 +210,11 @@ def mnist_sweep():
lambda_c_stepsize=0.25,
get_instance_image_by_id=get_instance_image,
get_instance_metadata=get_instance_label,
device="cuda")
device="cuda",
use_ml_flow=use_ml_flow,
ml_flow_run_name=ml_flow_run_name)
mnist_sweep()
breast_cancer_sweep()
app = FlaskHelper.app
app.logger.info('initialization complete')

View file

@@ -10,3 +10,4 @@ tensorflow-datasets==4.1.0
tensorflow-estimator==2.3.0
tensorflow-metadata==0.25.0
Pillow==7.2.0
mlflow==1.12.1

View file

@@ -6,10 +6,6 @@ import ReactDOM from "react-dom";
import * as d3 from "d3";
type SweepManagerState = {
sweepStatus: any
}
type SweepManagerProps = {
sweepStatus: any,
getSweepStatus: () => void,
@@ -17,25 +13,15 @@ type SweepManagerProps = {
getTrainingAndTestingData: () => void
}
class SweepManager extends Component<SweepManagerProps, SweepManagerState> {
class SweepManager extends Component<SweepManagerProps> {
constructor(props) {
super(props);
this.state = {
sweepStatus: this.props.sweepStatus
};
this.pollSweepStatus = this.pollSweepStatus.bind(this);
this.startSweep = this.startSweep.bind(this);
}
timeoutVar: any = null
componentWillReceiveProps(nextProps) {
this.setState({
sweepStatus: nextProps.sweepStatus
});
}
timeoutVar: NodeJS.Timeout = null
componentWillUnmount() {
if (this.timeoutVar != null) {
@@ -54,13 +40,13 @@ class SweepManager extends Component<SweepManagerProps, SweepManagerState> {
startSweep(evt) {
this.props.startSweep();
this.pollSweepStatus();
this.timeoutVar = setTimeout(this.pollSweepStatus, 500);
}
render() {
if (this.state.sweepStatus == null || !this.state.sweepStatus.running) {
if (this.timeoutVar != null && this.state.sweepStatus.percent_complete == 1.0) {
if (this.props.sweepStatus == null || !this.props.sweepStatus.running) {
if (this.timeoutVar != null && this.props.sweepStatus.percent_complete == 1.0) {
clearTimeout(this.timeoutVar);
this.timeoutVar = null;
this.props.getTrainingAndTestingData();
@@ -77,7 +63,7 @@ class SweepManager extends Component<SweepManagerProps, SweepManagerState> {
<div className="table">
Sweep in progress
<div>
{Math.floor(this.state.sweepStatus.percent_complete * 100)} % complete
{Math.floor(this.props.sweepStatus.percent_complete * 100)} % complete
</div>
</div>
);

View file

@@ -131,7 +131,6 @@ function getSweepStatus() {
function startSweep() {
return function(dispatch) {
dispatch(getSweepStatus());
makePostCall("api/v1/start_sweep", {});
}
}