Use `tmp_path` instead of `tmpdir` in pass unit tests (#550)
## Describe your changes Follow up to https://github.com/microsoft/Olive/pull/549. `tmp_path` returns a `pathlib.Path` object which is more convienient. https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Format your code by running `pre-commit run --all-files` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. ## (Optional) Issue link
This commit is contained in:
Родитель
829cf3e30f
Коммит
b178e2f2c8
|
@ -19,12 +19,12 @@ from olive.passes.onnx.inc_quantization import IncDynamicQuantization, IncQuanti
|
|||
@pytest.mark.skipif(
|
||||
platform.system() == "Windows", reason="Skip test on Windows. neural-compressor import is hanging on Windows."
|
||||
)
|
||||
def test_inc_quantization(tmpdir):
|
||||
ov_model = get_onnx_model(tmpdir)
|
||||
data_dir = Path(tmpdir) / "data"
|
||||
def test_inc_quantization(tmp_path):
|
||||
ov_model = get_onnx_model(tmp_path)
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
config = {"data_dir": data_dir, "dataloader_func": create_dataloader}
|
||||
output_folder = str(Path(tmpdir) / "quantized")
|
||||
output_folder = str(tmp_path / "quantized")
|
||||
|
||||
# create IncQuantization pass
|
||||
p = create_pass_from_dict(IncQuantization, config, disable_search=True)
|
||||
|
@ -61,10 +61,10 @@ def test_inc_quantization(tmpdir):
|
|||
assert "QLinearConv" in [node.op_type for node in quantized_model.load_model().graph.node]
|
||||
|
||||
|
||||
def get_onnx_model(tmpdir):
|
||||
def get_onnx_model(tmp_path):
|
||||
torch_hub_model_path = "chenyaofo/pytorch-cifar-models"
|
||||
pytorch_hub_model_name = "cifar10_mobilenetv2_x1_0"
|
||||
torch.hub.set_dir(tmpdir)
|
||||
torch.hub.set_dir(tmp_path)
|
||||
pytorch_model = PyTorchModel(
|
||||
model_loader=lambda torch_hub_model_path: torch.hub.load(torch_hub_model_path, pytorch_hub_model_name),
|
||||
model_path=torch_hub_model_path,
|
||||
|
@ -73,7 +73,7 @@ def get_onnx_model(tmpdir):
|
|||
onnx_conversion_config = {}
|
||||
|
||||
p = create_pass_from_dict(OnnxConversion, onnx_conversion_config, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# execute
|
||||
onnx_model = p.run(pytorch_model, None, output_folder)
|
||||
|
|
|
@ -17,11 +17,11 @@ class CustomizedParam:
|
|||
self.params = params
|
||||
|
||||
|
||||
def test_step_parser(tmpdir):
|
||||
def test_step_parser(tmp_path):
|
||||
from onnxruntime_extensions.tools.pre_post_processing import TokenizerParam
|
||||
|
||||
pytorch_model = get_superresolution_model()
|
||||
input_model = convert_superresolution_model(pytorch_model, tmpdir)
|
||||
input_model = convert_superresolution_model(pytorch_model, tmp_path)
|
||||
model = input_model.load_model()
|
||||
|
||||
step_config = Path(__file__).parent / "step_config.json"
|
||||
|
|
|
@ -8,10 +8,10 @@ from olive.passes.onnx.conversion import OnnxConversion
|
|||
|
||||
|
||||
@pytest.mark.parametrize("input_model", [get_pytorch_model(), get_hf_model_with_past()])
|
||||
def test_onnx_conversion_pass(input_model, tmpdir):
|
||||
def test_onnx_conversion_pass(input_model, tmp_path):
|
||||
# setup
|
||||
p = create_pass_from_dict(OnnxConversion, {}, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# The conversion need torch version > 1.13.1, otherwise, it will complain
|
||||
# Unsupported ONNX opset version: 18
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
from pathlib import Path
|
||||
from test.unit_test.utils import get_onnx_model
|
||||
|
||||
from olive.model import CompositeOnnxModel
|
||||
|
@ -6,7 +5,7 @@ from olive.passes.olive_pass import create_pass_from_dict
|
|||
from olive.passes.onnx.insert_beam_search import InsertBeamSearch
|
||||
|
||||
|
||||
def test_insert_beam_search_pass(tmpdir):
|
||||
def test_insert_beam_search_pass(tmp_path):
|
||||
# setup
|
||||
input_models = []
|
||||
input_models.append(get_onnx_model())
|
||||
|
@ -18,7 +17,7 @@ def test_insert_beam_search_pass(tmpdir):
|
|||
)
|
||||
|
||||
p = create_pass_from_dict(InsertBeamSearch, {}, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# execute
|
||||
p.run(composite_model, None, output_folder)
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
from pathlib import Path
|
||||
from test.unit_test.utils import get_onnx_model
|
||||
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.onnx.mixed_precision import OrtMixedPrecision
|
||||
|
||||
|
||||
def test_ort_mixed_precision_pass(tmpdir):
|
||||
def test_ort_mixed_precision_pass(tmp_path):
|
||||
# setup
|
||||
input_model = get_onnx_model()
|
||||
p = create_pass_from_dict(OrtMixedPrecision, {}, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# execute
|
||||
p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -2,18 +2,17 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from test.unit_test.utils import get_onnx_model
|
||||
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.onnx import OnnxModelOptimizer
|
||||
|
||||
|
||||
def test_onnx_model_optimizer_pass(tmpdir):
|
||||
def test_onnx_model_optimizer_pass(tmp_path):
|
||||
# setup
|
||||
input_model = get_onnx_model()
|
||||
p = create_pass_from_dict(OnnxModelOptimizer, {}, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# execute
|
||||
p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -8,10 +8,10 @@ from olive.passes.onnx.optimum_conversion import OptimumConversion
|
|||
|
||||
|
||||
@pytest.mark.parametrize("input_model", [get_optimum_model_by_hf_config(), get_optimum_model_by_model_path()])
|
||||
def test_optimum_conversion_pass(input_model, tmpdir):
|
||||
def test_optimum_conversion_pass(input_model, tmp_path):
|
||||
# setup
|
||||
p = create_pass_from_dict(OptimumConversion, {}, disable_search=True)
|
||||
output_folder = Path(tmpdir)
|
||||
output_folder = tmp_path
|
||||
|
||||
# execute
|
||||
onnx_model = p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from test.unit_test.utils import get_onnx_model
|
||||
from unittest.mock import patch
|
||||
|
||||
|
@ -13,18 +12,18 @@ from olive.passes.onnx import OrtPerfTuning
|
|||
|
||||
|
||||
@pytest.mark.parametrize("config", [{"input_names": ["input"], "input_shapes": [[1, 1]]}, {}])
|
||||
def test_ort_perf_tuning_pass(config, tmpdir):
|
||||
def test_ort_perf_tuning_pass(config, tmp_path):
|
||||
# setup
|
||||
input_model = get_onnx_model()
|
||||
p = create_pass_from_dict(OrtPerfTuning, config, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# execute
|
||||
p.run(input_model, None, output_folder)
|
||||
|
||||
|
||||
@patch("olive.model.ONNXModel.get_io_config")
|
||||
def test_ort_perf_tuning_pass_with_dynamic_shapes(mock_get_io_config, tmpdir):
|
||||
def test_ort_perf_tuning_pass_with_dynamic_shapes(mock_get_io_config, tmp_path):
|
||||
mock_get_io_config.return_value = {
|
||||
"input_names": ["input"],
|
||||
"input_shapes": [["input_0", "input_1"]],
|
||||
|
@ -36,7 +35,7 @@ def test_ort_perf_tuning_pass_with_dynamic_shapes(mock_get_io_config, tmpdir):
|
|||
|
||||
input_model = get_onnx_model()
|
||||
p = create_pass_from_dict(OrtPerfTuning, {}, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
with pytest.raises(TypeError) as e:
|
||||
# execute
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
from pathlib import Path
|
||||
|
||||
from olive.model import ONNXModel, PyTorchModel
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.onnx.append_pre_post_processing_ops import AppendPrePostProcessingOps
|
||||
from olive.passes.onnx.conversion import OnnxConversion
|
||||
|
||||
|
||||
def test_pre_post_processing_op(tmpdir):
|
||||
def test_pre_post_processing_op(tmp_path):
|
||||
# setup
|
||||
p = create_pass_from_dict(
|
||||
AppendPrePostProcessingOps,
|
||||
|
@ -15,14 +13,14 @@ def test_pre_post_processing_op(tmpdir):
|
|||
)
|
||||
|
||||
pytorch_model = get_superresolution_model()
|
||||
input_model = convert_superresolution_model(pytorch_model, tmpdir)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
input_model = convert_superresolution_model(pytorch_model, tmp_path)
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# execute
|
||||
p.run(input_model, None, output_folder)
|
||||
|
||||
|
||||
def test_pre_post_pipeline(tmpdir):
|
||||
def test_pre_post_pipeline(tmp_path):
|
||||
config = {
|
||||
"pre": [
|
||||
{"ConvertImageToBGR": {}},
|
||||
|
@ -105,10 +103,10 @@ def test_pre_post_pipeline(tmpdir):
|
|||
assert p is not None
|
||||
|
||||
pytorch_model = get_superresolution_model()
|
||||
input_model = convert_superresolution_model(pytorch_model, tmpdir)
|
||||
input_model = convert_superresolution_model(pytorch_model, tmp_path)
|
||||
input_model_graph = input_model.get_graph()
|
||||
assert input_model_graph.node[0].op_type == "Conv"
|
||||
output_folder = str(Path(tmpdir) / "onnx_pre_post")
|
||||
output_folder = str(tmp_path / "onnx_pre_post")
|
||||
|
||||
# execute
|
||||
model = p.run(input_model, None, output_folder)
|
||||
|
@ -182,8 +180,8 @@ def get_superresolution_model():
|
|||
return pytorch_model
|
||||
|
||||
|
||||
def convert_superresolution_model(pytorch_model, tmpdir):
|
||||
def convert_superresolution_model(pytorch_model, tmp_path):
|
||||
onnx_conversion_pass = create_pass_from_dict(OnnxConversion, {"target_opset": 15}, disable_search=True)
|
||||
onnx_model = onnx_conversion_pass.run(pytorch_model, None, str(Path(tmpdir) / "onnx"))
|
||||
onnx_model = onnx_conversion_pass.run(pytorch_model, None, str(tmp_path / "onnx"))
|
||||
|
||||
return onnx_model
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from test.unit_test.utils import get_onnx_model
|
||||
|
||||
import pytest
|
||||
|
@ -37,14 +36,14 @@ def test_fusion_options():
|
|||
assert vars(olive_fusion_options) == vars(ort_fusion_options)
|
||||
|
||||
|
||||
def test_ort_transformer_optimization_pass(tmpdir):
|
||||
def test_ort_transformer_optimization_pass(tmp_path):
|
||||
# setup
|
||||
input_model = get_onnx_model()
|
||||
config = {"model_type": "bert"}
|
||||
|
||||
config = OrtTransformersOptimization.generate_search_space(DEFAULT_CPU_ACCELERATOR, config, disable_search=True)
|
||||
p = OrtTransformersOptimization(DEFAULT_CPU_ACCELERATOR, config, True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# execute
|
||||
p.run(input_model, None, output_folder)
|
||||
|
@ -55,7 +54,7 @@ def test_ort_transformer_optimization_pass(tmpdir):
|
|||
@pytest.mark.parametrize(
|
||||
"accelerator_spec", [DEFAULT_CPU_ACCELERATOR, DEFAULT_GPU_CUDA_ACCELERATOR, DEFAULT_GPU_TRT_ACCELERATOR]
|
||||
)
|
||||
def test_invalid_ep_config(use_gpu, fp16, accelerator_spec, tmpdir):
|
||||
def test_invalid_ep_config(use_gpu, fp16, accelerator_spec, tmp_path):
|
||||
input_model = get_onnx_model()
|
||||
config = {"model_type": "bert", "use_gpu": use_gpu, "float16": fp16}
|
||||
config = OrtTransformersOptimization.generate_search_space(accelerator_spec, config, disable_search=True)
|
||||
|
@ -74,5 +73,5 @@ def test_invalid_ep_config(use_gpu, fp16, accelerator_spec, tmpdir):
|
|||
)
|
||||
|
||||
if not is_pruned:
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -9,14 +9,14 @@ from olive.passes.olive_pass import create_pass_from_dict
|
|||
from olive.passes.openvino.conversion import OpenVINOConversion
|
||||
|
||||
|
||||
def test_openvino_conversion_pass(tmpdir):
|
||||
def test_openvino_conversion_pass(tmp_path):
|
||||
# setup
|
||||
input_model = get_pytorch_model()
|
||||
dummy_input = get_pytorch_model_dummy_input(input_model)
|
||||
openvino_conversion_config = {"extra_config": {"example_input": dummy_input}}
|
||||
|
||||
p = create_pass_from_dict(OpenVINOConversion, openvino_conversion_config, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "openvino")
|
||||
output_folder = str(tmp_path / "openvino")
|
||||
|
||||
# execute
|
||||
openvino_model = p.run(input_model, None, output_folder)
|
||||
|
@ -27,7 +27,7 @@ def test_openvino_conversion_pass(tmpdir):
|
|||
assert (Path(openvino_model.model_path) / "ov_model.xml").is_file()
|
||||
|
||||
|
||||
def test_openvino_conversion_pass_no_example_input(tmpdir):
|
||||
def test_openvino_conversion_pass_no_example_input(tmp_path):
|
||||
# setup
|
||||
input_model = get_pytorch_model()
|
||||
openvino_conversion_config = {
|
||||
|
@ -35,7 +35,7 @@ def test_openvino_conversion_pass_no_example_input(tmpdir):
|
|||
}
|
||||
|
||||
p = create_pass_from_dict(OpenVINOConversion, openvino_conversion_config, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "openvino")
|
||||
output_folder = str(tmp_path / "openvino")
|
||||
|
||||
# execute
|
||||
openvino_model = p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -19,10 +19,10 @@ from olive.passes.openvino.quantization import OpenVINOQuantization
|
|||
|
||||
|
||||
@pytest.mark.parametrize("data_source", ["dataloader_func", "data_config"])
|
||||
def test_openvino_quantization(data_source, tmpdir):
|
||||
def test_openvino_quantization(data_source, tmp_path):
|
||||
# setup
|
||||
ov_model = get_openvino_model(tmpdir)
|
||||
data_dir = Path(tmpdir) / "data"
|
||||
ov_model = get_openvino_model(tmp_path)
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
config = {
|
||||
"engine_config": {"device": "CPU"},
|
||||
|
@ -60,7 +60,7 @@ def test_openvino_quantization(data_source, tmpdir):
|
|||
disable_search=True,
|
||||
accelerator_spec=AcceleratorSpec("cpu", "OpenVINOExecutionProvider"),
|
||||
)
|
||||
output_folder = str(Path(tmpdir) / "quantized")
|
||||
output_folder = str(tmp_path / "quantized")
|
||||
|
||||
# execute
|
||||
quantized_model = p.run(ov_model, None, output_folder)
|
||||
|
@ -72,10 +72,10 @@ def test_openvino_quantization(data_source, tmpdir):
|
|||
assert (Path(quantized_model.model_path) / "ov_model.mapping").is_file()
|
||||
|
||||
|
||||
def get_openvino_model(tmpdir):
|
||||
def get_openvino_model(tmp_path):
|
||||
torch_hub_model_path = "chenyaofo/pytorch-cifar-models"
|
||||
pytorch_hub_model_name = "cifar10_mobilenetv2_x1_0"
|
||||
torch.hub.set_dir(tmpdir)
|
||||
torch.hub.set_dir(tmp_path)
|
||||
pytorch_model = PyTorchModel(
|
||||
model_loader=lambda torch_hub_model_path: torch.hub.load(torch_hub_model_path, pytorch_hub_model_name),
|
||||
model_path=torch_hub_model_path,
|
||||
|
@ -90,7 +90,7 @@ def get_openvino_model(tmpdir):
|
|||
disable_search=True,
|
||||
accelerator_spec=AcceleratorSpec("cpu", "OpenVINOExecutionProvider"),
|
||||
)
|
||||
output_folder = str(Path(tmpdir) / "openvino")
|
||||
output_folder = str(tmp_path / "openvino")
|
||||
|
||||
# execute
|
||||
openvino_model = p.run(pytorch_model, None, output_folder)
|
||||
|
|
|
@ -18,7 +18,7 @@ def patched_find_all_linear_names(model):
|
|||
# quantization requires gpu so we will patch the model loading args with no quantization
|
||||
@patch("olive.passes.pytorch.qlora.HFModelLoadingArgs")
|
||||
@patch("olive.passes.pytorch.qlora.QLoRA.find_all_linear_names", side_effect=patched_find_all_linear_names)
|
||||
def test_qlora(patched_model_loading_args, patched_find_all_linear_names, tmpdir):
|
||||
def test_qlora(patched_model_loading_args, patched_find_all_linear_names, tmp_path):
|
||||
# setup
|
||||
model_name = "hf-internal-testing/tiny-random-OPTForCausalLM"
|
||||
task = "text-generation"
|
||||
|
@ -54,7 +54,7 @@ def test_qlora(patched_model_loading_args, patched_find_all_linear_names, tmpdir
|
|||
}
|
||||
|
||||
p = create_pass_from_dict(QLoRA, config, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "qlora")
|
||||
output_folder = str(tmp_path / "qlora")
|
||||
|
||||
# execute
|
||||
out = p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -2,23 +2,22 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
from test.unit_test.utils import create_dataloader, get_pytorch_model
|
||||
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.pytorch import QuantizationAwareTraining
|
||||
|
||||
|
||||
def test_quantization_aware_training_pass_default(tmpdir):
|
||||
def test_quantization_aware_training_pass_default(tmp_path):
|
||||
# setup
|
||||
input_model = get_pytorch_model()
|
||||
config = {
|
||||
"train_dataloader_func": create_dataloader,
|
||||
"checkpoint_path": str(Path(tmpdir) / "checkpoint"),
|
||||
"checkpoint_path": str(tmp_path / "checkpoint"),
|
||||
}
|
||||
|
||||
p = create_pass_from_dict(QuantizationAwareTraining, config, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "onnx")
|
||||
output_folder = str(tmp_path / "onnx")
|
||||
|
||||
# execute
|
||||
p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -2,15 +2,13 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from pathlib import Path
|
||||
|
||||
from olive.data.template import huggingface_data_config_template
|
||||
from olive.model import PyTorchModel
|
||||
from olive.passes.olive_pass import create_pass_from_dict
|
||||
from olive.passes.pytorch import SparseGPT
|
||||
|
||||
|
||||
def test_sparsegpt(tmpdir):
|
||||
def test_sparsegpt(tmp_path):
|
||||
# setup
|
||||
model_name = "sshleifer/tiny-gpt2"
|
||||
task = "text-generation"
|
||||
|
@ -37,7 +35,7 @@ def test_sparsegpt(tmpdir):
|
|||
}
|
||||
|
||||
p = create_pass_from_dict(SparseGPT, config, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "sparse")
|
||||
output_folder = str(tmp_path / "sparse")
|
||||
|
||||
# execute
|
||||
p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
@ -36,7 +35,7 @@ def mocked_torch_zeros(*args, **kwargs):
|
|||
# replace device in kwargs with "cpu"
|
||||
@patch("torch.zeros", side_effect=mocked_torch_zeros)
|
||||
def test_torch_trt_conversion_success(
|
||||
mock_torch_zeros, mock_torch_nn_module_to, mock_tensor_data_to_device, mock_torch_cuda_is_available, tmpdir
|
||||
mock_torch_zeros, mock_torch_nn_module_to, mock_tensor_data_to_device, mock_torch_cuda_is_available, tmp_path
|
||||
):
|
||||
# setup
|
||||
# mock trt utils since we don't have tensorrt and torch-tensorrt installed
|
||||
|
@ -79,7 +78,7 @@ def test_torch_trt_conversion_success(
|
|||
}
|
||||
|
||||
p = create_pass_from_dict(TorchTRTConversion, config, disable_search=True)
|
||||
output_folder = str(Path(tmpdir) / "sparse")
|
||||
output_folder = str(tmp_path / "sparse")
|
||||
|
||||
# execute
|
||||
model = p.run(input_model, None, output_folder)
|
||||
|
|
|
@ -34,18 +34,18 @@ def dummy_calibration_reader(data_dir=None, batch_size=1, *args, **kwargs):
|
|||
return RandomDataReader()
|
||||
|
||||
|
||||
def test_vitis_ai_quantization_pass(tmpdir):
|
||||
def test_vitis_ai_quantization_pass(tmp_path):
|
||||
# setup
|
||||
input_model = get_onnx_model()
|
||||
dummy_user_script = str(Path(tmpdir) / "dummy_user_script.py")
|
||||
dummy_data = str(Path(tmpdir) / "dummy_data")
|
||||
dummy_user_script = str(tmp_path / "dummy_user_script.py")
|
||||
dummy_data = str(tmp_path / "dummy_data")
|
||||
with open(dummy_user_script, "w") as f:
|
||||
f.write(" ")
|
||||
if not os.path.exists(dummy_data):
|
||||
os.mkdir(dummy_data)
|
||||
|
||||
config = {"user_script": dummy_user_script, "data_dir": dummy_data, "dataloader_func": dummy_calibration_reader}
|
||||
output_folder = str(Path(tmpdir) / "vitis_ai_quantized")
|
||||
output_folder = str(tmp_path / "vitis_ai_quantized")
|
||||
|
||||
# create VitisAIQuantization pass
|
||||
p = create_pass_from_dict(VitisAIQuantization, config, disable_search=True)
|
||||
|
|
Загрузка…
Ссылка в новой задаче