# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import json
import logging
import sys

from onnxruntime import __version__ as OrtVersion
from packaging import version

from olive.workflows import run as olive_run

SUPPORTED_WORKFLOWS = {
    "cpu": [
        ["conversion_merged", "transformers_optimization_fp32"],
        ["conversion_merged", "transformers_optimization_fp32", "onnx_dynamic_quant_int8"],
        ["conversion_merged", "transformers_optimization_fp32", "blockwise_quant_int4"],
    ],
    "gpu": [
        ["conversion_merged", "transformers_optimization_fp16"],
        ["conversion_merged", "transformers_optimization_fp16", "blockwise_quant_int4"],
        ["gptq_quant_int4", "conversion_merged", "transformers_optimization_fp32"],
        ["gptq_quant_int4", "conversion_merged", "transformers_optimization_fp16"],
    ],
}
DEVICE_TO_EP = {
    "cpu": "CPUExecutionProvider",
    "gpu": "CUDAExecutionProvider",
}


def get_args(raw_args):
    parser = argparse.ArgumentParser(description="Llama2 optimization")
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Llama-2-7b-hf",
        help="Model name, currently only supports llama2 7B/13B",
    )
    parser.add_argument("--gpu", action="store_true", required=False, help="Whether to use gpu for optimization.")
    parser.add_argument(
        "--use_gqa",
        action="store_true",
        required=False,
        help="Whether to use GQA (grouped query attention) instead of MHA (multi-head attention). Only supported on gpu.",
    )
    parser.add_argument(
        "--use_gptq",
        action="store_true",
        required=False,
        help="Whether to use GPTQ quantization instead of RTN quantization. Only supported on gpu.",
    )
    parser.add_argument(
        "--only_config",
        action="store_true",
        required=False,
        help="Whether to only dump the config file without running the optimization.",
    )
    parser.add_argument(
        "--remote_config",
        type=str,
        required=False,
        help="Path to the azureml config file. If provided, the config file will be used to create the client.",
    )
    parser.add_argument(
        "--cloud_cache",
        type=str,
        required=False,
        help="Path to the cloud cache config file. If provided, the cloud cache will be used for optimization.",
    )
    parser.add_argument(
        "--qlora",
        action="store_true",
        required=False,
        help="Whether to use qlora for optimization. Only supported on gpu.",
    )
    parser.add_argument("--tempdir", type=str, help="Root directory for tempfile directories and files", required=False)

    return parser.parse_args(raw_args)


def main(raw_args=None):
    if version.parse(OrtVersion) < version.parse("1.16.2"):
        raise ValueError("Please use onnxruntime>=1.16.2 for llama2 optimization")

    args = get_args(raw_args)

    if args.use_gqa and not args.gpu:
        raise ValueError("GQA is only supported on gpu.")

    if args.qlora:
        template_json, config_name = get_qlora_config()
    else:
        template_json, config_name = get_general_config(args)

    if args.remote_config:
        with open(args.remote_config) as f:
            remote_config = json.load(f)

        template_json["azureml_client"] = {
            "subscription_id": get_valid_config(remote_config, "subscription_id"),
            "resource_group": get_valid_config(remote_config, "resource_group"),
            "workspace_name": get_valid_config(remote_config, "workspace_name"),
            "keyvault_name": get_valid_config(remote_config, "keyvault_name"),
        }
        template_json["systems"]["aml_system"] = {
            "type": "AzureML",
            "accelerators": [{"device": "GPU", "execution_providers": ["CUDAExecutionProvider"]}],
            "aml_compute": get_valid_config(remote_config, "compute"),
            "aml_docker_config": {
                "base_image": "mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.8-cudnn8-ubuntu22.04",
                "conda_file_path": "conda_gpu.yaml",
            },
            "hf_token": True,
        }
        template_json["workflow_host"] = "aml_system"

    if args.cloud_cache:
        with open(args.cloud_cache) as f:
            cloud_cache_config = json.load(f)

        template_json["cloud_cache_config"] = {
            "account_url": get_valid_config(cloud_cache_config, "account_url"),
            "container_name": get_valid_config(cloud_cache_config, "container_name"),
            "upload_to_cloud": get_valid_config(cloud_cache_config, "upload_to_cloud", True),
        }

    # dump config
    with open(f"{config_name}.json", "w") as f:
        json.dump(template_json, f, indent=4)

    if not args.only_config:
        olive_run(template_json, tempdir=args.tempdir)  # pylint: disable=not-callable


def get_valid_config(config, key, default=None):
    if key in config:
        return config[key]
    if default is not None:
        return default
    raise ValueError(f"Key {key} is required in the config file.")


def get_qlora_config():
    with open("llama2_qlora.json") as f:
        template_json = json.load(f)
    return template_json, "llama2_gpu_qlora"


def get_general_config(args):
    with open("llama2_template.json") as f:
        template_json = json.load(f)

    model_name = args.model_name
    # update model name by replacing the placeholder token in the template
    # (assumes llama2_template.json uses "<model_name_placeholder>" where the model name should go)
    template_json_str = json.dumps(template_json)
    template_json_str = template_json_str.replace("<model_name_placeholder>", model_name)
    template_json = json.loads(template_json_str)

    # update configs
    device = "gpu" if args.gpu else "cpu"
    gqa = "gqa" if args.use_gqa else "mha"
    config_name = f"llama2_{device}_{gqa}"

    # add pass flows
    if not args.use_gptq:
        template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" not in flow[0]]
    else:
        template_json["pass_flows"] = [flow for flow in SUPPORTED_WORKFLOWS[device] if "gptq" in flow[0]]
        auto_gptq_logger = logging.getLogger("auto_gptq")
        auto_gptq_logger.addHandler(logging.StreamHandler(sys.stdout))
        auto_gptq_logger.setLevel(logging.INFO)

    # remove unused passes and set gqa related configs
    used_passes = {pass_name for pass_flow in SUPPORTED_WORKFLOWS[device] for pass_name in pass_flow}
    for pass_name in list(template_json["passes"].keys()):
        if pass_name not in used_passes:
            del template_json["passes"][pass_name]
            continue
        if not args.use_gqa and template_json["passes"][pass_name].get("evaluator", None) == "gqa_evaluator":
            # remove gqa evaluator if not using gqa
            del template_json["passes"][pass_name]["evaluator"]
template_json["passes"][pass_name]["evaluator"] if not args.use_gqa and template_json["passes"][pass_name].get("use_gqa", False): # set use_gqa to False if not using gqa template_json["passes"][pass_name]["use_gqa"] = False if not args.use_gqa: del template_json["evaluators"]["gqa_evaluator"] template_json["systems"]["local_system"]["accelerators"][0]["device"] = device template_json["systems"]["local_system"]["accelerators"][0]["execution_providers"] = [DEVICE_TO_EP[device]] template_json["output_dir"] = f"models/{config_name}/{model_name}" return template_json, config_name if __name__ == "__main__": main()