Unify OS constants (#1183)
## Describe your changes Unify OS constants ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link
This commit is contained in:
Родитель
8dbf3e3e58
Коммит
3d3277b646
|
@ -6,6 +6,8 @@ import json
|
|||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from olive.common.constants import OS
|
||||
|
||||
|
||||
def resolve_windows_config():
|
||||
|
||||
|
@ -18,5 +20,5 @@ def resolve_windows_config():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
resolve_windows_config()
|
||||
|
|
|
@ -7,6 +7,8 @@ import json
|
|||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from olive.common.constants import OS
|
||||
|
||||
|
||||
def raw_qnn_config():
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
@ -16,20 +18,20 @@ def raw_qnn_config():
|
|||
|
||||
sys_platform = platform.system()
|
||||
|
||||
if sys_platform == "Linux":
|
||||
if sys_platform == OS.LINUX:
|
||||
raw_qnn_config["passes"]["qnn_context_binary"] = {
|
||||
"type": "QNNContextBinaryGenerator",
|
||||
"config": {"backend": "libQnnHtp.so"},
|
||||
}
|
||||
raw_qnn_config["pass_flows"].append(["converter", "build_model_lib", "qnn_context_binary"])
|
||||
raw_qnn_config["passes"]["build_model_lib"]["config"]["lib_targets"] = "x86_64-linux-clang"
|
||||
elif sys_platform == "Windows":
|
||||
elif sys_platform == OS.WINDOWS:
|
||||
raw_qnn_config["passes"]["build_model_lib"]["config"]["lib_targets"] = "x86_64-windows-msvc"
|
||||
|
||||
for metric_config in raw_qnn_config["evaluators"]["common_evaluator"]["metrics"]:
|
||||
if sys_platform == "Windows":
|
||||
if sys_platform == OS.WINDOWS:
|
||||
metric_config["user_config"]["inference_settings"]["qnn"]["backend"] = "QnnCpu"
|
||||
elif sys_platform == "Linux":
|
||||
elif sys_platform == OS.LINUX:
|
||||
metric_config["user_config"]["inference_settings"]["qnn"]["backend"] = "libQnnCpu"
|
||||
|
||||
with Path("raw_qnn_sdk_config.json").open("w") as f:
|
||||
|
|
|
@ -11,6 +11,7 @@ from pathlib import Path
|
|||
from onnxruntime import __version__ as OrtVersion
|
||||
from packaging import version
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.workflows import run as olive_run
|
||||
|
||||
# flake8: noqa: T201
|
||||
|
@ -30,25 +31,25 @@ SUPPORTED_INFERENCE_CONFIG = {
|
|||
# -1 to use CPU
|
||||
"device_id": -1,
|
||||
"use_fp16": False,
|
||||
"use_step": platform.system() == "Linux",
|
||||
"use_step": platform.system() == OS.LINUX,
|
||||
},
|
||||
"cpu_int4": {
|
||||
"use_buffer_share": False,
|
||||
"device_id": -1,
|
||||
"use_fp16": False,
|
||||
"use_step": platform.system() == "Linux",
|
||||
"use_step": platform.system() == OS.LINUX,
|
||||
},
|
||||
"cuda_fp16": {
|
||||
"use_buffer_share": False,
|
||||
"device_id": 0,
|
||||
"use_fp16": True,
|
||||
"use_step": platform.system() == "Linux",
|
||||
"use_step": platform.system() == OS.LINUX,
|
||||
},
|
||||
"cuda_int4": {
|
||||
"use_buffer_share": False,
|
||||
"device_id": 0,
|
||||
"use_fp16": True,
|
||||
"use_step": platform.system() == "Linux",
|
||||
"use_step": platform.system() == OS.LINUX,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -180,7 +181,7 @@ def main(raw_args=None):
|
|||
with open(json_file_template) as f:
|
||||
template_json = json.load(f)
|
||||
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
legacy_optimization_setting(template_json)
|
||||
|
||||
# add pass flows
|
||||
|
|
|
@ -9,6 +9,8 @@ from pathlib import Path
|
|||
import pytest
|
||||
from utils import check_output, patch_config
|
||||
|
||||
from olive.common.constants import OS
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup():
|
||||
|
@ -25,7 +27,7 @@ def setup():
|
|||
@pytest.mark.parametrize("system", ["docker_system"])
|
||||
@pytest.mark.parametrize("olive_json", ["bert_ptq_cpu.json"])
|
||||
def test_bert(search_algorithm, execution_order, system, olive_json):
|
||||
if system == "docker_system" and platform.system() == "Windows":
|
||||
if system == "docker_system" and platform.system() == OS.WINDOWS:
|
||||
pytest.skip("Skip Linux containers on Windows host test case.")
|
||||
|
||||
from olive.workflows import run as olive_run
|
||||
|
|
|
@ -9,6 +9,7 @@ from pathlib import Path
|
|||
import pytest
|
||||
from utils import check_output, download_azure_blob
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import retry_func, run_subprocess
|
||||
from olive.logging import set_verbosity_debug
|
||||
|
||||
|
@ -20,9 +21,9 @@ class TestQnnToolkit:
|
|||
def setup(self, tmp_path):
|
||||
"""Download the qnn sdk."""
|
||||
blob, download_path = "", ""
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
blob, download_path = "qnn_sdk_windows.zip", "qnn_sdk_windows.zip"
|
||||
elif platform.system() == "Linux":
|
||||
elif platform.system() == OS.LINUX:
|
||||
blob, download_path = "qnn_sdk_linux.zip", "qnn_sdk_linux.zip"
|
||||
|
||||
download_azure_blob(
|
||||
|
@ -32,10 +33,10 @@ class TestQnnToolkit:
|
|||
)
|
||||
target_path = tmp_path / "qnn_sdk"
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
cmd = f"powershell Expand-Archive -Path {download_path} -DestinationPath {str(target_path)}"
|
||||
run_subprocess(cmd=cmd, check=True)
|
||||
elif platform.system() == "Linux":
|
||||
elif platform.system() == OS.LINUX:
|
||||
run_subprocess(cmd=f"unzip {download_path} -d {str(target_path)}", check=True)
|
||||
|
||||
os.environ["QNN_SDK_ROOT"] = str(target_path / "opt" / "qcom" / "aistack")
|
||||
|
@ -55,9 +56,9 @@ class TestQnnToolkit:
|
|||
)
|
||||
# install dependencies
|
||||
python_cmd = ""
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
python_cmd = str(Path(os.environ["QNN_SDK_ROOT"]) / "olive-pyenv" / "python.exe")
|
||||
elif platform.system() == "Linux":
|
||||
elif platform.system() == OS.LINUX:
|
||||
python_cmd = str(Path(os.environ["QNN_SDK_ROOT"]) / "olive-pyenv" / "bin" / "python")
|
||||
install_cmd = [
|
||||
python_cmd,
|
||||
|
@ -69,7 +70,7 @@ class TestQnnToolkit:
|
|||
packages = ["tensorflow==2.10.1", "numpy==1.23.5"]
|
||||
retry_func(run_subprocess, kwargs={"cmd": f"python -m pip install {' '.join(packages)}", "check": True})
|
||||
os.environ["PYTHONPATH"] = str(Path(os.environ["QNN_SDK_ROOT"]) / "lib" / "python")
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
os.environ["PATH"] = (
|
||||
str(Path(os.environ["QNN_SDK_ROOT"]) / "bin" / "x86_64-linux-clang")
|
||||
+ os.path.pathsep
|
||||
|
|
|
@ -9,6 +9,7 @@ from pathlib import Path
|
|||
import pytest
|
||||
from utils import check_output, download_azure_blob
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import retry_func, run_subprocess
|
||||
from olive.logging import set_verbosity_debug
|
||||
|
||||
|
@ -20,9 +21,9 @@ class TestSnpeToolkit:
|
|||
def setup(self, tmp_path):
|
||||
"""Download the snpe sdk."""
|
||||
blob, download_path = "", ""
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
blob, download_path = "snpe_sdk_windows.zip", "snpe_sdk_windows.zip"
|
||||
elif platform.system() == "Linux":
|
||||
elif platform.system() == OS.LINUX:
|
||||
blob, download_path = "snpe_sdk_linux.zip", "snpe_sdk_linux.zip"
|
||||
|
||||
download_azure_blob(
|
||||
|
@ -32,10 +33,10 @@ class TestSnpeToolkit:
|
|||
)
|
||||
target_path = tmp_path / "snpe_sdk"
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
cmd = f"powershell Expand-Archive -Path {download_path} -DestinationPath {str(target_path)}"
|
||||
run_subprocess(cmd=cmd, check=True)
|
||||
elif platform.system() == "Linux":
|
||||
elif platform.system() == OS.LINUX:
|
||||
run_subprocess(cmd=f"unzip {download_path} -d {str(target_path)}", check=True)
|
||||
os.environ["SNPE_ROOT"] = str(target_path)
|
||||
|
||||
|
@ -54,9 +55,9 @@ class TestSnpeToolkit:
|
|||
)
|
||||
# install dependencies
|
||||
python_cmd = ""
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
python_cmd = str(Path(os.environ["SNPE_ROOT"]) / "olive-pyenv" / "python.exe")
|
||||
elif platform.system() == "Linux":
|
||||
elif platform.system() == OS.LINUX:
|
||||
python_cmd = str(Path(os.environ["SNPE_ROOT"]) / "olive-pyenv" / "bin" / "python")
|
||||
install_cmd = [
|
||||
python_cmd,
|
||||
|
@ -68,7 +69,7 @@ class TestSnpeToolkit:
|
|||
packages = ["tensorflow==2.10.1", "numpy==1.23.5"]
|
||||
retry_func(run_subprocess, kwargs={"cmd": f"python -m pip install {' '.join(packages)}", "check": True})
|
||||
os.environ["PYTHONPATH"] = str(Path(os.environ["SNPE_ROOT"]) / "lib" / "python")
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
os.environ["PATH"] = (
|
||||
str(Path(os.environ["SNPE_ROOT"]) / "bin" / "x86_64-linux-clang")
|
||||
+ os.path.pathsep
|
||||
|
|
|
@ -11,6 +11,8 @@ from pathlib import Path
|
|||
import pytest
|
||||
from utils import check_output
|
||||
|
||||
from olive.common.constants import OS
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup():
|
||||
|
@ -34,7 +36,7 @@ def setup():
|
|||
def test_whisper(device_precision):
|
||||
from olive.workflows import run as olive_run
|
||||
|
||||
if platform.system() == "Windows" and device_precision[1].startswith("inc_int8"):
|
||||
if platform.system() == OS.WINDOWS and device_precision[1].startswith("inc_int8"):
|
||||
pytest.skip("Skip test on Windows. neural-compressor import is hanging on Windows.")
|
||||
|
||||
device, precision = device_precision
|
||||
|
|
|
@ -6,6 +6,8 @@ import json
|
|||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from olive.common.constants import OS
|
||||
|
||||
|
||||
def resolve_windows_config():
|
||||
|
||||
|
@ -18,5 +20,5 @@ def resolve_windows_config():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
resolve_windows_config()
|
||||
|
|
|
@ -2,5 +2,11 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from enum import Enum
|
||||
|
||||
DEFAULT_WORKFLOW_ID = "default_workflow"
|
||||
|
||||
|
||||
class OS(str, Enum):
|
||||
WINDOWS = "Windows"
|
||||
LINUX = "Linux"
|
||||
|
|
|
@ -18,6 +18,8 @@ import time
|
|||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from olive.common.constants import OS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -25,7 +27,7 @@ def run_subprocess(cmd, env=None, cwd=None, check=False):
|
|||
logger.debug("Running command: %s", cmd)
|
||||
|
||||
assert isinstance(cmd, (str, list)), f"cmd must be a string or a list, got {type(cmd)}."
|
||||
windows = platform.system() == "Windows"
|
||||
windows = platform.system() == OS.WINDOWS
|
||||
if isinstance(cmd, str):
|
||||
# In posix model, the cmd string will be handled with specific posix rules.
|
||||
# https://docs.python.org/3.8/library/shlex.html#parsing-rules
|
||||
|
|
|
@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Set, Union
|
|||
|
||||
import pkg_resources
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import copy_dir, retry_func, run_subprocess
|
||||
from olive.engine.packaging.packaging_config import (
|
||||
AzureMLDeploymentPackagingConfig,
|
||||
|
@ -707,14 +708,14 @@ def _download_ort_extensions_package(use_ort_extensions: bool, download_path: st
|
|||
# Hardcode the nightly version number for now until we have a better way to identify nightly version
|
||||
if version.startswith("0.8.0."):
|
||||
system = platform.system()
|
||||
if system == "Windows":
|
||||
if system == OS.WINDOWS:
|
||||
download_command = (
|
||||
f"{sys.executable} -m pip download -i "
|
||||
"https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ "
|
||||
f"onnxruntime-extensions=={version} --no-deps -d {download_path}"
|
||||
)
|
||||
run_subprocess(download_command)
|
||||
elif system == "Linux":
|
||||
elif system == OS.LINUX:
|
||||
logger.warning(
|
||||
"ONNXRuntime-Extensions nightly package is not available for Linux. "
|
||||
"Skip packaging ONNXRuntime-Extensions package. Please manually install ONNXRuntime-Extensions."
|
||||
|
|
|
@ -7,6 +7,7 @@ import platform
|
|||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.constants import Framework, ModelFileFormat
|
||||
from olive.hardware.accelerator import Device
|
||||
from olive.model.config import IoConfig
|
||||
|
@ -55,10 +56,10 @@ class QNNModelHandler(OliveModelHandler):
|
|||
logger.debug(
|
||||
"QNNModelHandler: lib_targets is not provided, using default lib_targets x86_64-linux-clang"
|
||||
)
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
lib_targets = "x86_64-linux-clang"
|
||||
model_lib_suffix = ".so"
|
||||
elif platform.system() == "Windows":
|
||||
elif platform.system() == OS.WINDOWS:
|
||||
# might be different for arm devices
|
||||
lib_targets = "x64"
|
||||
model_lib_suffix = ".dll"
|
||||
|
|
|
@ -8,6 +8,7 @@ import platform
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.constants import ModelFileFormat
|
||||
from olive.hardware import AcceleratorSpec
|
||||
from olive.model import QNNModelHandler
|
||||
|
@ -26,7 +27,7 @@ class QNNContextBinaryGenerator(Pass):
|
|||
|
||||
@classmethod
|
||||
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
raise NotImplementedError("QNNContextBinaryGenerator is not supported on Windows.")
|
||||
|
||||
return {
|
||||
|
|
|
@ -7,6 +7,7 @@ import platform
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.constants import ModelFileFormat
|
||||
from olive.hardware import AcceleratorSpec
|
||||
from olive.model import ONNXModelHandler, PyTorchModelHandler, QNNModelHandler, TensorFlowModelHandler
|
||||
|
@ -82,7 +83,7 @@ class QNNConversion(Pass):
|
|||
converter_program = [f"qnn-{converter_platform}-converter"]
|
||||
|
||||
runner = QNNSDKRunner(use_dev_tools=True)
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
converter_program = [
|
||||
"python",
|
||||
str(Path(runner.sdk_env.sdk_root_path) / "bin" / runner.sdk_env.target_arch / converter_program[0]),
|
||||
|
|
|
@ -8,6 +8,7 @@ import platform
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.constants import ModelFileFormat
|
||||
from olive.hardware import AcceleratorSpec
|
||||
from olive.model import QNNModelHandler
|
||||
|
@ -53,7 +54,7 @@ class QNNModelLibGenerator(Pass):
|
|||
) -> QNNModelHandler:
|
||||
main_cmd = "qnn-model-lib-generator"
|
||||
runner = QNNSDKRunner(use_dev_tools=True)
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
main_cmd = "python " + str(
|
||||
Path(runner.sdk_env.sdk_root_path) / "bin" / runner.sdk_env.target_arch / main_cmd
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@ import shutil
|
|||
from importlib import resources
|
||||
from pathlib import Path
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import run_subprocess
|
||||
from olive.platform_sdk.qualcomm.constants import SDKTargetDevice
|
||||
from olive.platform_sdk.qualcomm.qnn.env import QNNSDKEnv
|
||||
|
@ -31,14 +32,14 @@ def configure_dev(py_version: str, sdk: str):
|
|||
if sdk_arch not in (SDKTargetDevice.x86_64_linux, SDKTargetDevice.x86_64_windows):
|
||||
return
|
||||
|
||||
script_name = "create_python_env.sh" if platform.system() == "Linux" else "create_python_env.ps1"
|
||||
script_name = "create_python_env.sh" if platform.system() == OS.LINUX else "create_python_env.ps1"
|
||||
|
||||
logger.info("Configuring %s for %s with python %s...", sdk, sdk_arch, py_version)
|
||||
cmd = None
|
||||
with resources.path(resource_path, script_name) as create_python_env_path:
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
cmd = f"bash {create_python_env_path} -v {py_version} --sdk {sdk}"
|
||||
elif platform.system() == "Windows":
|
||||
elif platform.system() == OS.WINDOWS:
|
||||
cmd = f"powershell {create_python_env_path} {py_version} {sdk}"
|
||||
run_subprocess(cmd, check=True)
|
||||
logger.info("Done")
|
||||
|
|
|
@ -7,6 +7,7 @@ import os
|
|||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.platform_sdk.qualcomm.constants import SDKTargetDevice
|
||||
|
||||
|
||||
|
@ -37,14 +38,14 @@ class SDKEnv:
|
|||
"""Infer the target architecture from the SDK root path based on platform and processor."""
|
||||
system = platform.system()
|
||||
target_arch = None
|
||||
if system == "Linux":
|
||||
if system == OS.LINUX:
|
||||
machine = platform.machine()
|
||||
if machine == "x86_64":
|
||||
target_arch = SDKTargetDevice.x86_64_linux
|
||||
else:
|
||||
if fail_on_unsupported:
|
||||
raise ValueError(f"Unsupported machine {machine} on system {system}")
|
||||
elif system == "Windows":
|
||||
elif system == OS.WINDOWS:
|
||||
processor_identifier = os.environ.get("PROCESSOR_IDENTIFIER", "")
|
||||
if "ARM" in processor_identifier:
|
||||
target_arch = SDKTargetDevice.arm64x_windows
|
||||
|
@ -84,7 +85,7 @@ class SDKEnv:
|
|||
if platform.system() == "Linux":
|
||||
bin_path += delimiter + "/usr/bin"
|
||||
env["LD_LIBRARY_PATH"] = lib_path
|
||||
elif platform.system() == "Windows":
|
||||
elif platform.system() == OS.WINDOWS:
|
||||
bin_path += delimiter + lib_path
|
||||
|
||||
env["PATH"] = bin_path
|
||||
|
|
|
@ -7,6 +7,7 @@ import os
|
|||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.platform_sdk.qualcomm.env import SDKEnv
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class QNNSDKEnv(SDKEnv):
|
|||
env = super().env
|
||||
sdk_root_path = self.sdk_root_path
|
||||
delimiter = os.path.pathsep
|
||||
python_env_parent_folder = "Scripts" if platform.system() == "Windows" else "bin"
|
||||
python_env_parent_folder = "Scripts" if platform.system() == OS.WINDOWS else "bin"
|
||||
python_env_bin_path = str(Path(f"{sdk_root_path}/olive-pyenv/{python_env_parent_folder}"))
|
||||
|
||||
env["PATH"] += delimiter + os.environ["PATH"]
|
||||
|
@ -31,7 +32,7 @@ class QNNSDKEnv(SDKEnv):
|
|||
" to add the missing file."
|
||||
)
|
||||
env["PATH"] = python_env_bin_path + delimiter + env["PATH"]
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
for k, v in os.environ.items():
|
||||
if k not in env:
|
||||
env[k] = v
|
||||
|
@ -41,7 +42,7 @@ class QNNSDKEnv(SDKEnv):
|
|||
def get_qnn_backend(self, backend_name):
|
||||
backend_path = Path(self.sdk_root_path) / "lib" / self.target_arch / backend_name
|
||||
backend_path = (
|
||||
backend_path.with_suffix(".dll") if platform.system() == "Windows" else backend_path.with_suffix(".so")
|
||||
backend_path.with_suffix(".dll") if platform.system() == OS.WINDOWS else backend_path.with_suffix(".so")
|
||||
)
|
||||
|
||||
if not backend_path.exists():
|
||||
|
|
|
@ -13,6 +13,7 @@ from typing import Dict, Optional
|
|||
import numpy as np
|
||||
|
||||
from olive.common.config_utils import ConfigBase
|
||||
from olive.common.constants import OS
|
||||
from olive.constants import ModelFileFormat
|
||||
from olive.platform_sdk.qualcomm.runner import QNNSDKRunner
|
||||
from olive.platform_sdk.qualcomm.utils.input_list import get_input_ids
|
||||
|
@ -128,7 +129,7 @@ class QNNInferenceSession:
|
|||
# copy the raw file to the workspace and rename it
|
||||
if output_dir is not None:
|
||||
output_file = output_dir / f"{input_id}.raw"
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
output_file = output_dir / f"{input_id}.raw".replace(":", "_")
|
||||
raw_file.rename(output_file)
|
||||
result_files.append((input_id, output_file))
|
||||
|
|
|
@ -10,6 +10,7 @@ import time
|
|||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import run_subprocess
|
||||
from olive.platform_sdk.qualcomm.qnn.env import QNNSDKEnv
|
||||
from olive.platform_sdk.qualcomm.snpe.env import SNPESDKEnv
|
||||
|
@ -49,11 +50,11 @@ class SDKRunner:
|
|||
import platform
|
||||
|
||||
if isinstance(cmd, str):
|
||||
cmd_list = shlex.split(cmd, posix=(platform.system() != "Windows"))
|
||||
cmd_list = shlex.split(cmd, posix=(platform.system() != OS.WINDOWS))
|
||||
else:
|
||||
cmd_list = cmd
|
||||
|
||||
if platform.system() == "Windows" and cmd_list[0].startswith(("snpe-", "qnn-")):
|
||||
if platform.system() == OS.WINDOWS and cmd_list[0].startswith(("snpe-", "qnn-")):
|
||||
logger.debug("Resolving command %s on Windows.", cmd_list)
|
||||
cmd_dir = Path(self.sdk_env.sdk_root_path) / "bin" / self.sdk_env.target_arch
|
||||
cmd_name = cmd_list[0]
|
||||
|
|
|
@ -7,6 +7,7 @@ import os
|
|||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.platform_sdk.qualcomm.constants import SDKTargetDevice
|
||||
from olive.platform_sdk.qualcomm.env import SDKEnv
|
||||
|
||||
|
@ -21,7 +22,7 @@ class SNPESDKEnv(SDKEnv):
|
|||
target_arch = self.target_arch
|
||||
sdk_root_path = self.sdk_root_path
|
||||
delimiter = os.path.pathsep
|
||||
python_env_parent_folder = "" if platform.system() == "Windows" else "bin"
|
||||
python_env_parent_folder = "" if platform.system() == OS.WINDOWS else "bin"
|
||||
python_env_bin_path = str(Path(f"{sdk_root_path}/olive-pyenv/{python_env_parent_folder}"))
|
||||
|
||||
env["PATH"] += delimiter + os.environ["PATH"]
|
||||
|
@ -35,7 +36,7 @@ class SNPESDKEnv(SDKEnv):
|
|||
|
||||
env["PATH"] = python_env_bin_path + delimiter + env["PATH"]
|
||||
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
os_env = os.environ.copy()
|
||||
os_env.update(env)
|
||||
env = os_env
|
||||
|
|
|
@ -12,6 +12,7 @@ from typing import List
|
|||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.platform_sdk.qualcomm.runner import SNPESDKRunner as SNPERunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -191,7 +192,7 @@ def quantize_dlc(dlc_path: str, input_list: str, config: dict, output_file: str)
|
|||
extra_args: str = extra arguments to pass to the quantizer.
|
||||
"""
|
||||
quant_cmd = "snpe-dlc-quantize"
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
# snpe-dlc-quant is the Windows version of the quantizer tool
|
||||
# and it does not support the --enable_htp flag
|
||||
quant_cmd = "snpe-dlc-quant"
|
||||
|
@ -199,7 +200,7 @@ def quantize_dlc(dlc_path: str, input_list: str, config: dict, output_file: str)
|
|||
if config["use_enhanced_quantizer"]:
|
||||
cmd += " --use_enhanced_quantizer"
|
||||
if config["enable_htp"]:
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
logger.warning("--enable_htp is not supported on Windows")
|
||||
else:
|
||||
cmd += " --enable_htp"
|
||||
|
|
|
@ -14,6 +14,7 @@ import numpy as np
|
|||
|
||||
import olive.platform_sdk.qualcomm.snpe.utils.adb as adb_utils
|
||||
from olive.common.config_utils import validate_enum
|
||||
from olive.common.constants import OS
|
||||
from olive.platform_sdk.qualcomm.constants import PerfProfile, ProfilingLevel, SNPEDevice
|
||||
from olive.platform_sdk.qualcomm.runner import SNPESDKRunner as SNPERunner
|
||||
from olive.platform_sdk.qualcomm.utils.input_list import get_input_ids, resolve_input_list
|
||||
|
@ -97,7 +98,7 @@ def _snpe_net_run_adb(
|
|||
adb_utils.run_snpe_adb_command(cmd, android_target, push_snpe, runs=runs, sleep=sleep)
|
||||
|
||||
# replace ":" in output filenames with "_" if on windows before pulling
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
rename_cmd = (
|
||||
f"cd {target_output_dir} &&"
|
||||
" find -name '*:*' -exec sh -c 'for x; do mv $x $(echo $x | sed \"s/:/_/g\"); done' _ {} +"
|
||||
|
@ -214,9 +215,9 @@ def snpe_net_run(
|
|||
|
||||
# get the delimiter for the output files
|
||||
delimiter = None
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
delimiter = ":"
|
||||
elif platform.system() == "Windows":
|
||||
elif platform.system() == OS.WINDOWS:
|
||||
delimiter = "_"
|
||||
|
||||
# dictionary to store the results as numpy arrays
|
||||
|
@ -242,7 +243,7 @@ def snpe_net_run(
|
|||
if not (member / output_file_name).exists():
|
||||
# `:0` is already in the output name or source model was not tensorflow
|
||||
output_file_name = f"{output_name}.raw"
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
# replace ":" with "_" in the file name.
|
||||
output_file_name = output_file_name.replace(":", "_")
|
||||
raw_file = member / output_file_name
|
||||
|
@ -250,7 +251,7 @@ def snpe_net_run(
|
|||
# copy the raw file to the workspace and rename it
|
||||
if output_dir is not None:
|
||||
output_file = output_dir / f"{input_id}.{output_name}.raw"
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
# replace ":" with "_" in the file name.
|
||||
output_file = output_dir / f"{input_id}.{output_name}.raw".replace(":", "_")
|
||||
if len(output_names) == 1:
|
||||
|
|
|
@ -9,6 +9,7 @@ import time
|
|||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import run_subprocess
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -64,9 +65,9 @@ def run_adb_command(
|
|||
|
||||
# platform specific shell command
|
||||
if shell_cmd:
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
cmd = f"shell {cmd}"
|
||||
elif platform.system() == "Linux":
|
||||
elif platform.system() == OS.LINUX:
|
||||
cmd = f'shell "{cmd}"'
|
||||
|
||||
# run the command
|
||||
|
|
|
@ -11,6 +11,7 @@ import tempfile
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import run_subprocess
|
||||
from olive.evaluator.metric_result import MetricResult
|
||||
from olive.model import ModelConfig
|
||||
|
@ -52,7 +53,7 @@ class PythonEnvironmentSystem(OliveSystem):
|
|||
prepend_to_path=prepend_to_path,
|
||||
)
|
||||
if olive_managed_env:
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
temp_dir = os.path.join(os.environ.get("HOME", ""), "tmp")
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
|
@ -189,7 +190,7 @@ class PythonEnvironmentSystem(OliveSystem):
|
|||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
try:
|
||||
shutil.rmtree(self.environ["TMPDIR"])
|
||||
logger.info("Temporary directory '%s' removed.", self.environ["TMPDIR"])
|
||||
|
|
|
@ -12,6 +12,7 @@ from functools import lru_cache
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import hash_dir, run_subprocess
|
||||
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
|
||||
from olive.hardware.constants import PROVIDER_DOCKERFILE_MAPPING, PROVIDER_PACKAGE_MAPPING
|
||||
|
@ -56,7 +57,7 @@ def create_managed_system(system_config: "SystemConfig", accelerator: "Accelerat
|
|||
|
||||
from olive.systems.python_environment import PythonEnvironmentSystem
|
||||
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
destination_dir = os.path.join(os.environ.get("HOME", ""), "tmp")
|
||||
if not os.path.exists(destination_dir):
|
||||
os.makedirs(destination_dir)
|
||||
|
@ -67,7 +68,7 @@ def create_managed_system(system_config: "SystemConfig", accelerator: "Accelerat
|
|||
venv.create(venv_path, with_pip=True, system_site_packages=True)
|
||||
logger.info("Virtual environment '%s' created.", venv_path)
|
||||
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
python_environment_path = f"{venv_path}/Scripts"
|
||||
else:
|
||||
python_environment_path = f"{venv_path}/bin"
|
||||
|
|
|
@ -21,6 +21,7 @@ from typing import ClassVar, List
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.evaluator.metric_result import joint_metric_key
|
||||
from olive.hardware import DEFAULT_CPU_ACCELERATOR
|
||||
from olive.model import ModelConfig
|
||||
|
@ -55,7 +56,7 @@ class TestDockerEvaluation:
|
|||
("model_type", "model_config_func", "metric_func", "expected_res"),
|
||||
EVALUATION_TEST_CASE,
|
||||
)
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Docker target does not support windows")
|
||||
@pytest.mark.skipif(platform.system() == OS.WINDOWS, reason="Docker target does not support windows")
|
||||
def test_evaluate_model(self, model_type, model_config_func, metric_func, expected_res):
|
||||
docker_target = get_docker_target()
|
||||
model_config = model_config_func()
|
||||
|
|
|
@ -9,6 +9,7 @@ from test.integ_test.evaluator.docker_eval.utils import (
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR
|
||||
from olive.logging import set_default_logger_severity
|
||||
from olive.model.config.model_config import ModelConfig
|
||||
|
@ -16,7 +17,7 @@ from olive.passes.olive_pass import create_pass_from_dict
|
|||
from olive.passes.onnx.perf_tuning import OrtPerfTuning
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Docker target does not support windows")
|
||||
@pytest.mark.skipif(platform.system() == OS.WINDOWS, reason="Docker target does not support windows")
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup():
|
||||
get_directories()
|
||||
|
@ -25,7 +26,7 @@ def setup():
|
|||
delete_directories()
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Docker target does not support windows")
|
||||
@pytest.mark.skipif(platform.system() == OS.WINDOWS, reason="Docker target does not support windows")
|
||||
def test_pass_runner(tmp_path):
|
||||
docker_target = get_docker_target()
|
||||
model_config = get_onnx_model()
|
||||
|
|
|
@ -7,12 +7,13 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.model import ModelConfig
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Skip test on Windows.")
|
||||
@pytest.mark.skipif(platform.system() == OS.WINDOWS, reason="Skip test on Windows.")
|
||||
class TestOliveAzureMLSystem:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
|
|
|
@ -8,13 +8,14 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.logging import set_default_logger_severity
|
||||
from olive.model import ModelConfig
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Docker target does not support windows")
|
||||
@pytest.mark.skipif(platform.system() == OS.WINDOWS, reason="Docker target does not support windows")
|
||||
class TestOliveManagedDockerSystem:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
|
|
|
@ -7,6 +7,7 @@ from test.unit_test.utils import create_onnx_model_file, get_custom_metric, get_
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.systems.system_config import PythonEnvironmentTargetUserConfig, SystemConfig
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
@ -39,7 +40,7 @@ class TestOliveManagedPythonEnvironmentSystem:
|
|||
assert dml_res.metrics.value.__root__
|
||||
assert openvino_res.metrics.value.__root__
|
||||
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Test for Linux only")
|
||||
@pytest.mark.skipif(platform.system() == OS.WINDOWS, reason="Test for Linux only")
|
||||
def test_run_pass_evaluate_linux(self, tmp_path):
|
||||
# use the olive managed python environment as the test environment
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import torch
|
|||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.data.config import DataConfig
|
||||
from olive.model import PyTorchModelHandler
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
|
@ -19,7 +20,7 @@ from olive.passes.onnx.inc_quantization import IncDynamicQuantization, IncQuanti
|
|||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows", reason="Skip test on Windows. neural-compressor import is hanging on Windows."
|
||||
platform.system() == OS.WINDOWS, reason="Skip test on Windows. neural-compressor import is hanging on Windows."
|
||||
)
|
||||
def test_inc_quantization(tmp_path):
|
||||
ov_model = get_onnx_model(tmp_path)
|
||||
|
@ -64,7 +65,7 @@ def test_inc_quantization(tmp_path):
|
|||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows", reason="Skip test on Windows. neural-compressor import is hanging on Windows."
|
||||
platform.system() == OS.WINDOWS, reason="Skip test on Windows. neural-compressor import is hanging on Windows."
|
||||
)
|
||||
def test_inc_weight_only_quantization(tmp_path):
|
||||
ov_model = get_onnx_model(tmp_path)
|
||||
|
@ -95,7 +96,7 @@ def test_inc_weight_only_quantization(tmp_path):
|
|||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows", reason="Skip test on Windows. neural-compressor import is hanging on Windows."
|
||||
platform.system() == OS.WINDOWS, reason="Skip test on Windows. neural-compressor import is hanging on Windows."
|
||||
)
|
||||
@patch.dict("neural_compressor.quantization.STRATEGIES", {"auto": MagicMock()})
|
||||
@patch("olive.passes.onnx.inc_quantization.model_proto_to_olive_model")
|
||||
|
|
|
@ -18,6 +18,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.model import PyTorchModelHandler
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.onnx.conversion import OnnxConversion, OnnxOpVersionConversion
|
||||
|
@ -38,7 +39,7 @@ def test_onnx_conversion_pass(input_model, tmp_path):
|
|||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows" or not torch.cuda.is_available(),
|
||||
platform.system() == OS.WINDOWS or not torch.cuda.is_available(),
|
||||
reason="bitsandbytes requires Linux GPU.",
|
||||
)
|
||||
@pytest.mark.parametrize("add_quantized_modules", [True, False])
|
||||
|
|
|
@ -12,6 +12,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.data.template import huggingface_data_config_template
|
||||
from olive.model import PyTorchModelHandler
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
|
@ -83,7 +84,7 @@ def test_lora(tmp_path):
|
|||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows" or not torch.cuda.is_available(),
|
||||
platform.system() == OS.WINDOWS or not torch.cuda.is_available(),
|
||||
reason="bitsandbytes requires Linux GPU.",
|
||||
)
|
||||
def test_qlora(tmp_path):
|
||||
|
@ -96,7 +97,7 @@ def test_qlora(tmp_path):
|
|||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows" or not torch.cuda.is_available(),
|
||||
platform.system() == OS.WINDOWS or not torch.cuda.is_available(),
|
||||
reason="bitsandbytes requires Linux GPU.",
|
||||
)
|
||||
def test_loftq(tmp_path):
|
||||
|
|
|
@ -4,6 +4,7 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.qnn.conversion import QNNConversion
|
||||
|
||||
|
@ -27,7 +28,7 @@ def test_qnn_conversion_cmd(mocked_qnn_sdk_runner, config, tmp_path):
|
|||
mocked_qnn_sdk_runner.return_value.sdk_env.target_arch = "x86_64"
|
||||
p.run(input_model, None, tmp_path)
|
||||
converter_program = ["qnn-onnx-converter"]
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
converter_program = ["python", "sdk_root_path\\bin\\x86_64\\qnn-onnx-converter"]
|
||||
|
||||
expected_cmd_list = [
|
||||
|
|
|
@ -11,6 +11,7 @@ from unittest.mock import patch
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.platform_sdk.qualcomm.runner import SDKRunner
|
||||
from olive.platform_sdk.qualcomm.snpe.utils.adb import run_adb_command
|
||||
|
||||
|
@ -35,7 +36,7 @@ def test_run_adb_command(mock_run_subprocess, android_target):
|
|||
|
||||
|
||||
def test_run_snpe_command():
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
os.environ["SNPE_ROOT"] = "C:\\snpe"
|
||||
target_arch = "x86_64-windows-msvc"
|
||||
else:
|
||||
|
@ -51,7 +52,7 @@ def test_run_snpe_command():
|
|||
mock_run_subprocess.return_value = CompletedProcess(None, returncode=0, stdout=b"stdout", stderr=b"stderr")
|
||||
runner = SDKRunner(platform="SNPE")
|
||||
stdout, _ = runner.run(cmd="snpe-net-run --container xxxx")
|
||||
if platform.system() == "Linux":
|
||||
if platform.system() == OS.LINUX:
|
||||
env = {
|
||||
"LD_LIBRARY_PATH": "/snpe/lib/x86_64-linux-clang",
|
||||
"PATH": f"/snpe/bin/x86_64-linux-clang:/usr/bin:{os.environ['PATH']}",
|
||||
|
|
|
@ -16,6 +16,7 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import run_subprocess
|
||||
from olive.constants import Framework
|
||||
from olive.evaluator.metric import AccuracySubType, LatencySubType
|
||||
|
@ -107,7 +108,7 @@ class TestIsolatedORTEvaluator:
|
|||
venv_path = tmp_path / "venv"
|
||||
venv.create(venv_path, with_pip=True)
|
||||
# python path
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
python_environment_path = f"{venv_path}/Scripts"
|
||||
else:
|
||||
python_environment_path = f"{venv_path}/bin"
|
||||
|
|
|
@ -19,6 +19,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import run_subprocess
|
||||
from olive.evaluator.metric_result import MetricResult, joint_metric_key
|
||||
from olive.hardware import DEFAULT_CPU_ACCELERATOR
|
||||
|
@ -39,7 +40,7 @@ class TestPythonEnvironmentSystem:
|
|||
venv_path = tmp_path / "venv"
|
||||
venv.create(venv_path, with_pip=True)
|
||||
# python path
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
self.python_environment_path = Path(venv_path) / "Scripts"
|
||||
else:
|
||||
self.python_environment_path = Path(venv_path) / "bin"
|
||||
|
|
|
@ -12,6 +12,7 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
|
||||
from olive.cache import clean_pass_run_cache, create_cache, download_resource, get_cache_sub_dirs, save_model
|
||||
from olive.common.constants import OS
|
||||
from olive.resource_path import AzureMLModel
|
||||
|
||||
|
||||
|
@ -47,7 +48,7 @@ class TestCache:
|
|||
model_cache_file_path = str((model_cache_dir / "0_p(・◡・)p.json").resolve())
|
||||
with open(model_cache_file_path, "w") as model_cache_file:
|
||||
model_data = f'{{"model_path": "{model_p}"}}'
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
model_data = model_data.replace("\\", "//")
|
||||
model_cache_file.write(model_data)
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from olive.common.constants import OS
|
||||
from olive.common.utils import run_subprocess
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
@ -25,7 +26,7 @@ class DependencySetupEnvBuilder(venv.EnvBuilder):
|
|||
|
||||
@pytest.fixture()
|
||||
def config_json(tmp_path):
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
ep = "DmlExecutionProvider"
|
||||
else:
|
||||
ep = "CUDAExecutionProvider"
|
||||
|
@ -45,7 +46,7 @@ def test_dependency_setup(tmp_path, config_json):
|
|||
builder = DependencySetupEnvBuilder(with_pip=True)
|
||||
builder.create(str(tmp_path))
|
||||
|
||||
if platform.system() == "Windows":
|
||||
if platform.system() == OS.WINDOWS:
|
||||
python_path = tmp_path / "Scripts" / "python"
|
||||
ort_extra = "onnxruntime-directml"
|
||||
else:
|
||||
|
|
Загрузка…
Ссылка в новой задаче