add editorconfig rules & enable pylint (#659)

## Describe your changes

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
This commit is contained in:
Mike Guo 2023-10-20 13:43:00 +08:00 committed by GitHub
Parent 2b5aef171a
Commit 4df44a5f45
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
116 changed files: 449 additions and 253 deletions


@ -240,10 +240,10 @@ def run_perf_comparison(cur_dir, model_name, device, model_root_path, test_num):
olive_config_path = cur_dir / "configs" / olive_config
run_with_config("olive", olive_config_path, metric_res)
print(f"All metric results {metric_res}")
for model, v in metric_res.items():
for v in metric_res.values():
for metric_name, metric_value_list in v.items():
vsum = sum(float(v) for v in metric_value_list)
metric_res[model][metric_name] = round((vsum / len(metric_value_list)), 4)
v[metric_name] = round((vsum / len(metric_value_list)), 4)
print(f"Avg metric results {metric_res}")
return metric_res
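For context on this first hunk: the loop now iterates `metric_res.values()` because the inner dict is updated through the loop variable itself, so the key lookup (and the key) are no longer needed. A minimal sketch of the same averaging pattern, using made-up metric data:

```python
# Minimal sketch of the averaging pattern above; the metric data is made up.
metric_res = {"bert": {"latency_ms": ["1.2", "1.4"], "accuracy": ["0.90", "0.92"]}}

for v in metric_res.values():  # key not needed once v is mutated in place
    for metric_name, metric_value_list in v.items():
        vsum = sum(float(x) for x in metric_value_list)
        v[metric_name] = round(vsum / len(metric_value_list), 4)

print(metric_res)  # {'bert': {'latency_ms': 1.3, 'accuracy': 0.91}}
```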

.editorconfig (new file, 7 lines added)

@ -0,0 +1,7 @@
root = true
[*]
trim_trailing_whitespace = true
insert_final_newline = true
indent_style = space
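The new .editorconfig applies these three whitespace rules to every file, and the EDITORCONFIG-CHECKER linter added to .lintrunner.toml below enforces them. As a rough illustration (not part of the commit), a check equivalent to the first two rules looks like this:

```python
# Rough illustration of what trim_trailing_whitespace and insert_final_newline
# require; this helper is not part of the commit.
from pathlib import Path


def editorconfig_violations(path: str) -> list:
    text = Path(path).read_text(encoding="utf-8")
    problems = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        if line != line.rstrip():
            problems.append(f"{path}:{lineno}: trailing whitespace")
    if text and not text.endswith("\n"):
        problems.append(f"{path}: missing final newline")
    return problems


print(editorconfig_violations(".editorconfig"))  # expected: []
```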


@ -57,13 +57,41 @@ init_command = [
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements-lintrunner.txt'
'--requirement=requirements-dev.txt'
]
is_formatter = true
[[linter]]
code = 'BLACK-ISORT'
include_patterns = [
'**/*.py'
]
exclude_patterns = [
]
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'black_isort_linter',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements-dev.txt'
]
is_formatter = true
[[linter]]
code = 'PYLINT'
include_patterns = [
'**/*.py'
]
exclude_patterns = [
]
@ -84,7 +112,7 @@ init_command = [
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements-lintrunner.txt'
'--requirement=requirements-dev.txt'
]
[[linter]]
@ -208,3 +236,43 @@ init_command = [
'--dry-run={{DRYRUN}}',
'toml-sort==0.23.1'
]
[[linter]]
code = 'EDITORCONFIG-CHECKER'
include_patterns = ['**']
exclude_patterns = [
'**/*.ipynb'
]
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'editorconfig_checker_linter',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'--requirement=requirements-dev.txt'
]
[[linter]]
code = 'REQUIREMENTS-TXT'
is_formatter = true
include_patterns = ['requirements*.txt']
exclude_patterns = []
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'requirements_txt_linter',
'--',
'@{{PATHSFILE}}'
]


@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
@ -9,12 +9,12 @@ repos:
- id: requirements-txt-fixer
- repo: https://github.com/psf/black
rev: 22.3.0
rev: 23.7.0
hooks:
- id: black
name: Format code
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 6.1.0
hooks:
- id: flake8
name: Check PEP8
@ -29,6 +29,6 @@ repos:
- id: absolufy-imports
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.0.291
rev: v0.1.0
hooks:
- id: ruff


@ -8,6 +8,7 @@ import sys
import sphinx_rtd_theme
# ruff: noqa
# pylint: skip-file
sys.path.append(os.path.abspath("exts"))
# Configuration file for the Sphinx documentation builder.
#


@ -17,6 +17,8 @@ from olive.common.auto_config import AutoConfigClass
from olive.hardware import DEFAULT_CPU_ACCELERATOR
from olive.passes import Pass
# pylint: skip-file
def import_class(class_name: str):
module_name = ".".join(class_name.split(".")[:-1])


@ -29,6 +29,7 @@ from olive.model import OliveModel
datasets_logging.disable_progress_bar()
datasets_logging.set_verbosity_error()
# pylint: disable=attribute-defined-outside-init, protected-access
# This file is only used by bert_inc_ptq_cpu, bert_qat_customized_train_loop_cpu
# -------------------------------------------------------------------------
@ -178,7 +179,7 @@ class IncBertDataset:
return input_dict, label
def inc_glue_calibration_reader(data_dir, batch_size=1, *args, **kwargs):
def inc_glue_calibration_reader(data_dir, batch_size, *args, **kwargs):
bert_dataset = BertDataset("Intel/bert-base-uncased-mrpc")
bert_dataset = IncBertDataset(bert_dataset.get_eval_dataset())
return DefaultDataLoader(dataset=bert_dataset, batch_size=batch_size)
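This is the first of several dataloader helpers that drop the `batch_size=1` default: pylint's keyword-arg-before-vararg check warns that a default placed before `*args` is easily clobbered by an extra positional argument. A small sketch of the failure mode (function names are illustrative):

```python
# Why keyword-arg-before-vararg is risky; these readers are illustrative only.
def reader_old(data_dir, batch_size=1, *args, **kwargs):
    return data_dir, batch_size, args


def reader_new(data_dir, batch_size, *args, **kwargs):
    return data_dir, batch_size, args


print(reader_old("data", 8, "extra"))  # ('data', 8, ('extra',))
print(reader_old("data", "extra"))     # ('data', 'extra', ()) - 'extra' silently became batch_size
print(reader_new("data", 1, "extra"))  # explicit batch_size, no ambiguity
```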


@ -21,6 +21,7 @@ def main():
file_template_content = f.read()
file_template_content = file_template_content.replace("{USER_SCRIPT}", user_script_path)
# pylint: disable=consider-using-with
config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json", prefix="config_")
with open(config_file.name, "w") as f: # noqa: PTH123
f.write(file_template_content)


@ -51,9 +51,7 @@ class CifarDataLoader(DataLoader):
raise IndexError
return (
self.pictures[index].numpy()[
None,
],
self.pictures[index].numpy()[None,],
self.labels[index],
)


@ -6,6 +6,7 @@ import argparse
import json
import os
import shutil
import sys
from pathlib import Path
import config
@ -15,6 +16,8 @@ from packaging import version
from olive.model import CompositeOnnxModel, ONNXModel
from olive.workflows import run as olive_run
# pylint: disable=redefined-outer-name
def optimize(model_name: str, optimized_model_dir: Path):
from google.protobuf import __version__ as protobuf_version
@ -22,7 +25,7 @@ def optimize(model_name: str, optimized_model_dir: Path):
# protobuf 4.x aborts with OOM when optimizing dolly
if version.parse(protobuf_version) > version.parse("3.20.3"):
print("This script requires protobuf 3.20.3. Please ensure your package version matches requirements.txt.")
exit(1)
sys.exit(1)
ort.set_default_logger_severity(4)
script_dir = Path(__file__).resolve().parent
@ -103,7 +106,7 @@ if __name__ == "__main__":
"databricks/dolly-v2-7b": 4096,
}
if args.model not in list(model_to_hidden_size.keys()):
if args.model not in model_to_hidden_size:
print(
f"WARNING: {args.model} is not an officially supported model for this example and may not work as expected."
)
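Two recurring fixes appear in this script and the Stable Diffusion scripts that follow: `exit(1)` becomes `sys.exit(1)` (the bare `exit` is a site-module convenience meant for interactive sessions), and `x not in list(d.keys())` becomes `x not in d` (dict membership already tests keys, with no temporary list). A short sketch, with a hypothetical model name and version guard:

```python
# Sketch of the two idioms; the unsupported model name and version guard are hypothetical.
import sys

model_to_hidden_size = {"databricks/dolly-v2-7b": 4096}
model = "databricks/dolly-v2-3b"

if model not in model_to_hidden_size:  # same test as `not in list(d.keys())`, no temp list
    print(f"WARNING: {model} is not an officially supported model for this example.")

if sys.version_info < (3, 8):  # placeholder guard, mirroring the protobuf version check
    print("This script requires Python 3.8 or newer")
    sys.exit(1)  # preferred over the interactive-only builtin exit(1)
```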


@ -5,6 +5,7 @@
import argparse
import json
import shutil
import sys
import threading
import tkinter as tk
import tkinter.ttk as ttk
@ -22,6 +23,8 @@ from user_script import get_base_model_name
from olive.model import ONNXModel
from olive.workflows import run as olive_run
# pylint: disable=redefined-outer-name
def run_inference_loop(
pipeline, prompt, num_images, batch_size, image_size, num_inference_steps, image_callback=None, step_callback=None
@ -171,7 +174,7 @@ def optimize(
# protobuf 4.x aborts with OOM when optimizing unet
if version.parse(protobuf_version) > version.parse("3.20.3"):
print("This script requires protobuf 3.20.3. Please ensure your package version matches requirements.txt.")
exit(1)
sys.exit(1)
ort.set_default_logger_severity(4)
script_dir = Path(__file__).resolve().parent
@ -325,7 +328,7 @@ if __name__ == "__main__":
"stabilityai/stable-diffusion-2-1-base": 768,
}
if args.model_id not in list(model_to_image_size.keys()):
if args.model_id not in model_to_image_size:
print(
f"WARNING: {args.model_id} is not an officially supported model for this example and may not work "
"as expected."
@ -333,7 +336,7 @@ if __name__ == "__main__":
if version.parse(ort.__version__) < version.parse("1.15.0"):
print("This script requires onnxruntime-directml 1.15.0 or newer")
exit(1)
sys.exit(1)
script_dir = Path(__file__).resolve().parent
unoptimized_model_dir = script_dir / "models" / "unoptimized" / args.model_id


@ -5,6 +5,7 @@
import argparse
import json
import shutil
import sys
import threading
import tkinter as tk
import tkinter.ttk as ttk
@ -23,6 +24,8 @@ from PIL import Image, ImageTk
from olive.model import ONNXModel
from olive.workflows import run as olive_run
# pylint: disable=redefined-outer-name
def run_inference_loop(
pipeline,
@ -237,7 +240,7 @@ def optimize(
# protobuf 4.x aborts with OOM when optimizing unet
if version.parse(protobuf_version) > version.parse("3.20.3"):
print("This script requires protobuf 3.20.3. Please ensure your package version matches requirements.txt.")
exit(1)
sys.exit(1)
ort.set_default_logger_severity(4)
script_dir = Path(__file__).resolve().parent
@ -431,7 +434,7 @@ if __name__ == "__main__":
},
}
if args.model_id not in list(model_to_config.keys()):
if args.model_id not in model_to_config:
print(
f"WARNING: {args.model_id} is not an officially supported model for this example and may not work as "
"expected."
@ -439,7 +442,7 @@ if __name__ == "__main__":
if version.parse(ort.__version__) < version.parse("1.15.0"):
print("This script requires onnxruntime-directml 1.15.0 or newer")
exit(1)
sys.exit(1)
script_dir = Path(__file__).resolve().parent
@ -458,11 +461,11 @@ if __name__ == "__main__":
if is_refiner_model and not args.optimize and args.base_images is None:
print("--base_images needs to be provided when executing a refiner model without --optimize")
exit(1)
sys.exit(1)
if not is_refiner_model and args.base_images is not None:
print("--base_images should only be provided for refiner models")
exit(1)
sys.exit(1)
if args.optimize or not optimized_model_dir.exists():
# TODO(PatriceVignola): clean up warning filter (mostly during conversion from torch to ONNX)


@ -14,6 +14,8 @@ from olive.constants import Framework
ort.set_default_logger_severity(3)
# pylint: disable=not-callable, useless-parent-delegation
def tokenize_function(examples):
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
@ -84,7 +86,7 @@ def create_pt_dataloader(data_dir, batch_size, *args, **kwargs):
return Dataloader(batch_size=batch_size)
def create_onnx_dataloader(data_dir, batch_size=1, *args, **kwargs):
def create_onnx_dataloader(data_dir, batch_size, *args, **kwargs):
return OnnxDataloader(batch_size=batch_size)


@ -14,6 +14,8 @@ from torchvision import transforms
from olive.common.utils import run_subprocess
# pylint: disable=consider-using-with
def get_directories():
current_dir = Path(__file__).resolve().parent


@ -49,7 +49,7 @@ class MobileNetCalibrationDataReader(CalibrationDataReader):
self.iter = None
def evaluation_dataloader(data_dir, batch_size=1, *args, **kwargs):
def evaluation_dataloader(data_dir, batch_size, *args, **kwargs):
dataset = MobileNetDataset(data_dir)
return DataLoader(dataset, batch_size=batch_size)
@ -58,5 +58,5 @@ def post_process(output):
return output.argmax(axis=1)
def mobilenet_calibration_reader(data_dir, batch_size=1, *args, **kwargs):
def mobilenet_calibration_reader(data_dir, batch_size, *args, **kwargs):
return MobileNetCalibrationDataReader(data_dir, batch_size=batch_size)


@ -124,7 +124,7 @@ class PileDataloader:
return
def calib_dataloader(data_dir, batch_size=1, *args, **kwargs):
def calib_dataloader(data_dir, batch_size, *args, **kwargs):
model_path = kwargs.pop("model_path")
return PileDataloader(model_path, batch_size=batch_size)


@ -151,9 +151,8 @@ def main():
data_download_path = data_dir / "cifar-10-python.tar.gz"
urllib.request.urlretrieve("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", data_download_path)
file = tarfile.open(data_download_path)
file.extractall(data_dir)
file.close()
with tarfile.open(data_download_path) as file:
file.extractall(data_dir)
prepare_model(args.num_epochs, models_dir, data_dir)
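The CIFAR download helper above moves from a manual `open`/`close` pair to a `with` block, so the archive handle is released even if extraction raises. A minimal sketch of the same pattern with placeholder paths:

```python
# Minimal sketch of the with-statement form; paths are placeholders.
import tarfile
from pathlib import Path

data_dir = Path("data")
archive = data_dir / "cifar-10-python.tar.gz"

with tarfile.open(archive) as f:  # closed automatically, even if extractall raises
    f.extractall(data_dir)
```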


@ -112,7 +112,7 @@ class ResnetCalibrationDataReader(CalibrationDataReader):
return None
def resnet_calibration_reader(data_dir, batch_size=16, *args, **kwargs):
def resnet_calibration_reader(data_dir, batch_size, *args, **kwargs):
return ResnetCalibrationDataReader(data_dir, batch_size=batch_size)


@ -13,6 +13,8 @@ from torchvision import transforms
from olive.common.utils import run_subprocess
# pylint: skip-file
def get_directories():
current_dir = Path(__file__).resolve().parent


@ -109,7 +109,7 @@ def _main():
if not np.all(list(results.values())):
pprint.pprint(results)
raise Exception("Inference tests failed!")
raise Exception("Inference tests failed!") # pylint: disable=broad-exception-raised
print("Inference test completed successfully!")
return 0


@ -25,7 +25,6 @@ def setup():
@pytest.mark.parametrize("olive_json", ["bert_cuda_gpu.json"])
@pytest.mark.parametrize("enable_cuda_graph", [True, False])
def test_bert(search_algorithm, execution_order, system, olive_json, enable_cuda_graph):
from olive.workflows import run as olive_run
olive_config = patch_config(olive_json, search_algorithm, execution_order, system, is_gpu=True)


@ -24,7 +24,6 @@ def setup():
@pytest.mark.parametrize("system", ["local_system"])
@pytest.mark.parametrize("olive_json", ["bert_ptq_cpu.json"])
def test_bert(search_algorithm, execution_order, system, olive_json):
from olive.workflows import run as olive_run
olive_config = patch_config(olive_json, search_algorithm, execution_order, system)


@ -40,6 +40,11 @@ def test_bert(olive_test_knob):
if olive_test_knob[3] == "aml_system":
# remove the invalid OpenVINOExecutionProvider for bert aml system.
olive_config["engine"]["execution_providers"] = ["CPUExecutionProvider"]
# remove goal for aml system since sometimes the aml job will be reused.
# If the jobs perf cannot meet the goal, the test will fail definitely.
metrics = olive_config["evaluators"]["common_evaluator"]["metrics"]
metrics[0]["sub_types"][0].pop("goal", None)
metrics[1]["sub_types"][0].pop("goal", None)
output = olive_run(olive_config)
check_output(output)


@ -5,6 +5,8 @@
import json
import os
# pylint: skip-file
def check_output(footprints):
"""Check if the search output is valid."""


@ -164,9 +164,9 @@ def decoder_dummy_inputs(model):
return tuple(inputs.to_list())
def whisper_audio_decoder_dataloader(data_dir, batch_size=None, *args, **kwargs):
def whisper_audio_decoder_dataloader(data_dir, batch_size, *args, **kwargs):
return WhisperDataset(data_dir=data_dir, use_audio_decoder=True)
def whisper_no_audio_decoder_dataloader(data_dir, batch_size=None, *args, **kwargs):
def whisper_no_audio_decoder_dataloader(data_dir, batch_size, *args, **kwargs):
return WhisperDataset(data_dir=data_dir, use_audio_decoder=False)


@ -71,7 +71,7 @@ class WhisperDataset:
def __getitem__(self, idx):
data = self.data[idx]
label = self.labels[idx] if self.labels is not None else -1
label = self.labels[idx] if self.labels is not None else -1 # pylint: disable=unsubscriptable-object
return data, label
def __iter__(self):


@ -18,6 +18,7 @@ from olive.model import ONNXModel
sys.path.append(str(Path(__file__).parent / "code"))
# pylint: disable=wrong-import-position, wrong-import-order
from whisper_dataset import WhisperDataset # noqa: E402
@ -80,7 +81,7 @@ def main(raw_args=None):
args.audio_path = download_audio_test_data()
# temporary directory for storing audio file
temp_dir = tempfile.TemporaryDirectory()
temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
temp_dir_path = Path(temp_dir.name)
temp_audio_path = temp_dir_path / Path(args.audio_path).name
shutil.copy(args.audio_path, temp_audio_path)
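Where the temporary directory must outlive the current block, as in this Whisper test script, the commit keeps `tempfile.TemporaryDirectory()` and silences consider-using-with; where it does not (the MLflow model loader and the ONNX conversion pass further down), the directory is wrapped in a `with` block. A small sketch of the preferred form, with a placeholder input file:

```python
# Sketch of the context-manager form used where the temp dir is short-lived.
import shutil
import tempfile
from pathlib import Path

audio_path = Path("sample.wav")  # placeholder input file

with tempfile.TemporaryDirectory() as tmp:
    tmp_audio = Path(tmp) / audio_path.name
    shutil.copy(audio_path, tmp_audio)  # work on the copy inside the block
    print(tmp_audio.exists())           # True while the directory exists
# the directory and its contents are deleted here automatically
```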


@ -9,10 +9,10 @@ from typing import List, Optional, Union
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset as TorchDataset
class BaseDataset(Dataset):
class BaseDataset(TorchDataset):
"""Define the Olive dataset which should return the data with following format.
1. [data, label] for supervised learning
@ -80,7 +80,7 @@ class BaseDataset(Dataset):
data_dict = {k: [] for k in first_input}
data_dict[label_name] = []
# loop over the dataset
for i in range(len(self)):
for i in range(len(self)): # pylint: disable=consider-using-enumerate
data, label = deepcopy(self[i])
for k, v in data.items():
data_dict[k].append(v)
@ -98,6 +98,7 @@ class DummyDataset(BaseDataset):
if input_names is None, the dummy dataset will return a tuple of tensors
else the dummy dataset will return a dict of tensors
"""
# pylint: disable=super-init-not-called
self.input_shapes = input_shapes
self.input_names = input_names
self.input_types = input_types or ["float32"] * len(input_shapes)
@ -163,6 +164,7 @@ class RawDataset(BaseDataset):
:param annotations_file: Name of the file containing the annotations. This file should be present in the
data_dir. It is assumed to be a .npy file containing a numpy array. Default is None.
"""
# pylint: disable=super-init-not-called
self.data_dir = Path(data_dir).resolve()
self.input_names = input_names
assert len(input_names) == len(input_shapes), "Number of input shapes should be equal to number of input names."


@ -81,7 +81,7 @@ class TextGenParams(ConfigBase):
@validator("drop_short_sequences", always=True)
def _check_padding(cls, v, values):
if "pad_to_max_len" not in values:
ValueError("Invalid pad_to_max_len")
raise ValueError("Invalid pad_to_max_len")
if v and values["pad_to_max_len"]:
raise ValueError("pad_to_max_len and drop_short_sequences cannot both be True")
return v
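This validator change is a genuine bug fix surfaced while enabling pylint: the old line only constructed a `ValueError` without raising it, so the invalid config passed validation. A standalone sketch of the difference (no pydantic needed):

```python
# Standalone sketch of the validator bug and its fix.
def check_padding_old(values: dict) -> None:
    if "pad_to_max_len" not in values:
        ValueError("Invalid pad_to_max_len")  # constructed but never raised: a no-op


def check_padding_new(values: dict) -> None:
    if "pad_to_max_len" not in values:
        raise ValueError("Invalid pad_to_max_len")  # actually rejects the bad config


check_padding_old({})  # silently passes despite the missing key
try:
    check_padding_new({})
except ValueError as e:
    print(f"caught: {e}")
```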
@ -463,6 +463,7 @@ def text_gen_pair_pre_process(dataset, tokenizer, all_kwargs):
if tokenizer.pad_token_id is None:
raise ValueError("Tokenizer does not have a pad token")
# add padding to max_len
# pylint: disable=not-callable
input_ids = torch.nn.functional.pad(
input_ids, (0, max_len - input_ids.shape[0]), value=tokenizer.pad_token_id
)
@ -580,6 +581,7 @@ def format_pair_dataset(dataset, args):
}
if args.pair_format == TextGenPairFormat.ALPACA:
# pylint: disable=unnecessary-lambda
def extract_alpaca_dataset(example):
# extract new input from instruction and input


@ -148,7 +148,7 @@ class DataConfig(ConfigBase):
# 3. Use value from component.params if already defined
# 4. Use the default value from the function signature
if param not in v.params:
if info.kind == info.VAR_POSITIONAL or info.kind == info.VAR_KEYWORD:
if info.kind in (info.VAR_POSITIONAL, info.VAR_KEYWORD):
continue
elif info.default is info.empty:
logger.debug(f"Missing parameter {param} for component {k}")


@ -189,6 +189,8 @@ class Engine:
def initialize(self):
"""Initialize engine state. This should be done before running the registered passes."""
# pylint: disable=attribute-defined-outside-init
cache_dir = self._config.cache_dir
if self._config.clean_cache:
cache_utils.clean_cache(cache_dir)
@ -210,7 +212,7 @@ class Engine:
# so we check for both when determining the new model number
model_files = list(self._model_cache_path.glob("*_*"))
if len(model_files) > 0:
self._new_model_number = max([int(model_file.stem.split("_")[0]) for model_file in model_files]) + 1
self._new_model_number = max(int(model_file.stem.split("_")[0]) for model_file in model_files) + 1
# clean pass run cache if requested
# removes all run cache for pass type and all children elements
@ -361,7 +363,7 @@ class Engine:
if packaging_config and self.passes:
# TODO(trajep): should we support package input model?
# TODO(trajep): do you support packaging pytorch models?
logger.info(f"Package top ranked {sum([len(f.nodes) for f in outputs.values()])} models as artifacts")
logger.info(f"Package top ranked {sum(len(f.nodes) for f in outputs.values())} models as artifacts")
generate_output_artifacts(
packaging_config,
self.footprints,
@ -452,7 +454,7 @@ class Engine:
# These passes will be added to the search space
self.pass_flows_search_spaces = []
for pass_flow in self.pass_flows:
self.pass_search_spaces = []
self.pass_search_spaces = [] # pylint: disable=attribute-defined-outside-init
for pass_name in pass_flow:
p: Pass = self.passes[pass_name]["pass"]
self.pass_search_spaces.append((pass_name, p.search_space()))
@ -750,7 +752,7 @@ class Engine:
while True:
new_model_number = self._new_model_number
self._new_model_number += 1
if list(self._model_cache_path.glob(f"{new_model_number}_*")) == []:
if not list(self._model_cache_path.glob(f"{new_model_number}_*")):
break
return new_model_number
@ -895,7 +897,7 @@ class Engine:
if not should_prune:
# evaluate the model
evaluator_config = self.evaluator_for_pass(pass_id)
evaluator_config = self.evaluator_for_pass(pass_id) # pylint: disable=undefined-loop-variable
if self.no_search and evaluator_config is None:
# skip evaluation if no search and no evaluator
signal = None
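The Engine hunks above replace `max([...])` and `sum([...])` with bare generator expressions (pylint's consider-using-generator), which gives the same result without building a throwaway list. A tiny sketch with made-up model file names:

```python
# consider-using-generator: same result, no intermediate list; file names are made up.
model_files = ["0_conversion.json", "1_transformers_opt.json", "2_quantization.json"]

new_model_number = max(int(name.split("_")[0]) for name in model_files) + 1
print(new_model_number)  # 3
```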


@ -33,7 +33,7 @@ def generate_output_artifacts(
pf_footprints: Dict[AcceleratorSpec, Footprint],
output_dir: Path,
):
if sum([len(f.nodes) if f.nodes else 0 for f in pf_footprints.values()]) == 0:
if sum(len(f.nodes) if f.nodes else 0 for f in pf_footprints.values()) == 0:
logger.warning("No model is selected. Skip packaging output artifacts.")
return
if packaging_config.type == PackagingType.Zipfile:
@ -101,9 +101,9 @@ def _package_candidate_models(
model_resource_path = create_resource_path(model_path) if model_path else None
model_type = pf_footprint.get_model_type(model_id)
if model_type == "ONNXModel":
with tempfile.TemporaryDirectory(dir=model_dir, prefix="olive_tmp") as tempdir:
# save to tempdir first since model_path may be a folder
temp_resource_path = create_resource_path(model_resource_path.save_to_dir(tempdir, "model", True))
with tempfile.TemporaryDirectory(dir=model_dir, prefix="olive_tmp") as model_tempdir:
# save to model_tempdir first since model_path may be a folder
temp_resource_path = create_resource_path(model_resource_path.save_to_dir(model_tempdir, "model", True))
# save to model_dir
if temp_resource_path.type == ResourceType.LocalFile:
# if model_path is a file, rename it to model_dir / model.onnx
@ -187,6 +187,7 @@ def _generate_onnx_mlflow_model(model_dir, inference_config):
def _package_onnxruntime_packages(tempdir, pf_footprint: Footprint):
# pylint: disable=not-an-iterable
installed_packages = pkg_resources.working_set
onnxruntime_pkg = [i for i in installed_packages if i.key.startswith("onnxruntime")]
ort_nightly_pkg = [i for i in installed_packages if i.key.startswith("ort-nightly")]


@ -60,7 +60,7 @@ class AccuracyBase(AutoConfigClass):
continue
annotation = info.annotation if info.annotation != info.empty else None
default_value, required = (info.default, False) if info.default != info.empty else (None, True)
if info.kind == info.VAR_KEYWORD or info.kind == info.VAR_POSITIONAL:
if info.kind in (info.VAR_KEYWORD, info.VAR_POSITIONAL):
required = False
metric_config[param] = ConfigParam(type_=annotation, required=required, default_value=default_value)
return metric_config


@ -80,12 +80,12 @@ class HuggingfaceMetrics(MetricBackend):
),
}
def measure_sub_metric(self, model_output, target, sub_metric: SubMetric) -> SubMetricResult:
def measure_sub_metric(self, model_output, targets, sub_metric: SubMetric) -> SubMetricResult:
load_params = sub_metric.metric_config.load_params or {}
evaluator = self.evaluate_module.load(sub_metric.name, **load_params)
compute_params = sub_metric.metric_config.compute_params or {}
result = evaluator.compute(predictions=model_output[0], references=target, **compute_params)
result = evaluator.compute(predictions=model_output[0], references=targets, **compute_params)
if not result:
raise ValueError(
f"Cannot find the result for {sub_metric.name} in the metric result. Please check your parameters."


@ -40,6 +40,8 @@ from olive.snpe.data_loader import SNPECommonDataLoader, SNPEDataLoader
logger = logging.getLogger(__name__)
# pylint: disable=useless-parent-delegation
class OliveModelOutput(NamedTuple):
preds: Any
@ -479,8 +481,6 @@ class OnnxEvaluator(OliveEvaluator, framework=Framework.ONNX):
device: Device,
execution_providers: Union[str, List[str]],
) -> MetricResult:
from copy import deepcopy
from mpi4py.futures import MPIPoolExecutor
config = {
@ -573,8 +573,6 @@ class OnnxEvaluator(OliveEvaluator, framework=Framework.ONNX):
device,
execution_providers: Union[str, List[str]],
) -> MetricResult:
from copy import deepcopy
from mpi4py.futures import MPIPoolExecutor
config = {
@ -642,6 +640,7 @@ class OnnxEvaluator(OliveEvaluator, framework=Framework.ONNX):
@staticmethod
def disable_ort_fallback(session, execution_providers):
# pylint: disable=protected-access
if execution_providers:
assert isinstance(execution_providers, (str, list))
execution_providers = [execution_providers] if isinstance(execution_providers, str) else execution_providers
@ -728,6 +727,7 @@ class PyTorchEvaluator(OliveEvaluator, framework=Framework.PYTORCH):
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
) -> MetricResult:
# pylint: disable=expression-not-assigned
warmup_num, repeat_test_num, _ = get_latency_config_from_metric(metric)
session = model.prepare_session(inference_settings=self.get_inference_settings(metric), device=device)


@ -6,7 +6,7 @@ import logging
def set_verbosity(verbose):
logging.getLogger(__name__.split(".")[0]).setLevel(verbose)
logging.getLogger(__name__.split(".", maxsplit=1)[0]).setLevel(verbose)
def set_verbosity_info():


@ -321,8 +321,8 @@ class ONNXModel(ONNXModelBase):
def prepare_session(
self,
inference_settings: Dict[str, Any],
device: Device,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
):
@ -571,49 +571,44 @@ class PyTorchModel(OliveModel):
def load_mlflow_model(self):
logger.info(f"Loading MLFlow model from {self.model_path}")
tmp_dir = tempfile.TemporaryDirectory(prefix="mlflow_tmp")
tmp_dir_path = Path(tmp_dir.name)
with tempfile.TemporaryDirectory(prefix="mlflow_tmp") as tmp_dir:
shutil.copytree(os.path.join(self.model_path, "data/model"), tmp_dir, dirs_exist_ok=True)
shutil.copytree(os.path.join(self.model_path, "data/config"), tmp_dir, dirs_exist_ok=True)
shutil.copytree(os.path.join(self.model_path, "data/tokenizer"), tmp_dir, dirs_exist_ok=True)
shutil.copytree(os.path.join(self.model_path, "data/model"), tmp_dir_path, dirs_exist_ok=True)
shutil.copytree(os.path.join(self.model_path, "data/config"), tmp_dir_path, dirs_exist_ok=True)
shutil.copytree(os.path.join(self.model_path, "data/tokenizer"), tmp_dir_path, dirs_exist_ok=True)
with open(os.path.join(self.model_path, "MLmodel")) as fp: # noqa: PTH123
mlflow_data = yaml.safe_load(fp)
# default flavor is "hftransformersv2" from azureml.evaluate.mlflow>=0.0.8
# "hftransformers" from azureml.evaluate.mlflow<0.0.8
# TODO(trajep): let user specify flavor name if needed
# to support other flavors in mlflow not only hftransformers
hf_pretrained_class = None
flavors = mlflow_data.get("flavors", {})
if not flavors:
raise ValueError(
"Invalid MLFlow model format. Please make sure the input model"
" format is same with the result of mlflow.transformers.save_model,"
" or aml_mlflow.hftransformers.save_model from azureml.evaluate.mlflow"
)
with open(os.path.join(self.model_path, "MLmodel")) as fp: # noqa: PTH123
mlflow_data = yaml.safe_load(fp)
# default flavor is "hftransformersv2" from azureml.evaluate.mlflow>=0.0.8
# "hftransformers" from azureml.evaluate.mlflow<0.0.8
# TODO(trajep): let user specify flavor name if needed
# to support other flavors in mlflow not only hftransformers
hf_pretrained_class = None
flavors = mlflow_data.get("flavors", {})
if not flavors:
raise ValueError(
"Invalid MLFlow model format. Please make sure the input model"
" format is same with the result of mlflow.transformers.save_model,"
" or aml_mlflow.hftransformers.save_model from azureml.evaluate.mlflow"
)
if "hftransformersv2" in flavors:
hf_pretrained_class = flavors["hftransformersv2"].get("hf_pretrained_class", "AutoModel")
elif "hftransformers" in flavors:
hf_pretrained_class = flavors["hftransformers"].get("hf_pretrained_class", "AutoModel")
else:
raise ValueError(
"Unsupported MLFlow model flavor. Currently only support hftransformersv2/hftransformers."
)
if "hftransformersv2" in flavors:
hf_pretrained_class = flavors["hftransformersv2"].get("hf_pretrained_class", "AutoModel")
elif "hftransformers" in flavors:
hf_pretrained_class = flavors["hftransformers"].get("hf_pretrained_class", "AutoModel")
else:
raise ValueError(
"Unsupported MLFlow model flavor. Currently only support hftransformersv2/hftransformers."
)
model_loader = huggingface_model_loader(hf_pretrained_class)
loaded_model = model_loader(tmp_dir_path)
loaded_model.eval()
tmp_dir.cleanup()
return loaded_model
model_loader = huggingface_model_loader(hf_pretrained_class)
loaded_model = model_loader(tmp_dir)
loaded_model.eval()
return loaded_model
def prepare_session(
self,
inference_settings: Dict[str, Any],
device: Device,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
):
@ -784,8 +779,8 @@ class SNPEModel(OliveModel):
def prepare_session(
self,
inference_settings: Dict[str, Any],
device: Device,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
) -> SNPEInferenceSession:
@ -824,8 +819,8 @@ class TensorFlowModel(OliveModel):
def prepare_session(
self,
inference_settings: Dict[str, Any],
device: Device,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
):
@ -876,8 +871,8 @@ class OpenVINOModel(OliveModel):
def prepare_session(
self,
inference_settings: Dict[str, Any],
device: Device,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
):
@ -922,19 +917,23 @@ class DistributedOnnxModel(ONNXModelBase):
def ranked_model_path(self, rank: int) -> Union[Path, str]:
return self.model_filepaths[rank]
def load_model(self, rank: int) -> ONNXModel:
def load_model(self, rank: int = None) -> ONNXModel:
return ONNXModel(self.model_filepaths[rank], inference_settings=self.inference_settings)
def prepare_session(
self,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.GPU,
device: Device = Device.GPU, # pylint: disable=signature-differs
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = 0,
):
raise RuntimeError("DistributedOnnxModel doesn't have a session of its own")
def get_default_execution_providers(self, filepath: str, device: Device):
def get_default_execution_providers(self, device: Device):
"""Return a list of supported default execution providers."""
return ["CPUExecutionProvider"]
def get_default_execution_providers_with_model(self, filepath: str, device: Device):
# return firstly available ep as ort default ep
available_providers = DistributedOnnxModel.get_execution_providers(device)
for ep in available_providers:
@ -999,8 +998,8 @@ class CompositeOnnxModel(ONNXModelBase):
def prepare_session(
self,
inference_settings: Dict[str, Any],
device: Device,
inference_settings: Optional[Dict[str, Any]] = None,
device: Device = Device.CPU,
execution_providers: Union[str, List[str]] = None,
rank: Optional[int] = None,
):


@ -221,6 +221,9 @@ class HFConfig(ConfigBase):
model = load_huggingface_model_from_task(self.task, model_name_or_path, **loading_args)
elif self.model_class:
model = load_huggingface_model_from_model_class(self.model_class, model_name_or_path, **loading_args)
else:
raise ValueError("Either task or model_class must be specified")
return model
def load_model_config(self, model_path: str = None):
@ -317,15 +320,16 @@ def patched_supported_features_mapping(*supported_features: str, onnx_config_cls
def get_onnx_config(model_name: str, task: str, feature: Optional[str] = None):
# pylint: disable=protected-access
from transformers.onnx import FeaturesManager
from olive.model.hf_onnx_config import ADDITIONAL_MODEL_TYPES
# patch FeaturesManager._SUPPORTED_MODEL_TYPE to support additional models in olive
for model_type in ADDITIONAL_MODEL_TYPES:
for model_type, feature_list in ADDITIONAL_MODEL_TYPES.items():
if model_type in FeaturesManager._SUPPORTED_MODEL_TYPE:
continue
features, onnx_config_cls = ADDITIONAL_MODEL_TYPES[model_type]
features, onnx_config_cls = feature_list
FeaturesManager._SUPPORTED_MODEL_TYPE[model_type] = patched_supported_features_mapping(
*features, onnx_config_cls=onnx_config_cls
)


@ -52,8 +52,8 @@ class IOConfig(ConfigBase):
return v
dynamic_axes = v
for k, v in dynamic_axes.items():
dynamic_axes[k] = {int(kk): vv for kk, vv in v.items()}
for k, value in dynamic_axes.items():
dynamic_axes[k] = {int(kk): vv for kk, vv in value.items()}
return dynamic_axes
@validator("string_to_int_dim_params")


@ -398,8 +398,11 @@ class FullPassConfig(ConfigBase):
return v
def create_pass(self):
if not isinstance(self.accelerator, dict):
raise ValueError(f"accelerator must be a dict, got {self.accelerator}")
pass_cls = Pass.registry[self.type.lower()]
accelerator_spec = AcceleratorSpec(**self.accelerator)
accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping
return pass_cls(accelerator_spec, self.config, self.disable_search)


@ -160,24 +160,22 @@ class OnnxConversion(Pass):
# there might be multiple files created during export, so we need to track the dir
# if there are other processes writing to the same dir, we might end up deleting files created by
# other processes
tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp")
tmp_dir_path = Path(tmp_dir.name)
tmp_model_path = str(tmp_dir_path / Path(output_model_path).name)
with tempfile.TemporaryDirectory(prefix="olive_tmp") as tmp_dir:
tmp_dir_path = Path(tmp_dir)
tmp_model_path = str(tmp_dir_path / Path(output_model_path).name)
torch.onnx.export(
pytorch_model,
dummy_inputs,
tmp_model_path,
export_params=True,
opset_version=config["target_opset"],
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
onnx_model = onnx.load(tmp_model_path)
# the model is loaded into memory, so it's safe to delete previously exported file(s)
tmp_dir.cleanup()
torch.onnx.export(
pytorch_model,
dummy_inputs,
tmp_model_path,
export_params=True,
opset_version=config["target_opset"],
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
onnx_model = onnx.load(tmp_model_path)
# the model is loaded into memory, so it's safe to delete previously exported file(s)
# Workaround as described under IOConfig.string_to_int_dim_params: change numeric dim_param to dim_value
if io_config.string_to_int_dim_params:


@ -261,9 +261,6 @@ class IncQuantization(Pass):
_requires_user_script = True
def _initialize(self):
super()._initialize()
@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
"""Override this method to return False by using the accelerator spec information."""
@ -331,7 +328,7 @@ class IncQuantization(Pass):
# and return evaluation value.
# temporarily save model as onnx model
tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp")
tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp") # pylint: disable=consider-using-with
tmp_model_path = Path(tmp_dir.name) / "tmp_model.onnx"
# save as olive onnx model


@ -43,9 +43,9 @@ class ModelOptimizer:
self.fuse_transpose_qat()
def fuse_transpose_qat(self):
for node_name in self.node_name2module:
node = self.node_name2module[node_name][0]
node_index = self.node_name2module[node_name][1]
for module in self.node_name2module.values():
node = module[0]
node_index = module[1]
if node.op_type == "Transpose":
if "DequantizeLinear" in node.input[0]:
dequant_node_name = node.input[0][:-9]
@ -188,12 +188,11 @@ class ModelOptimizer:
logger.debug(f"ModelOptimization: inserted node {cast_node.name}")
self.model = o_model.model
self.model = o_model.model # pylint: disable=attribute-defined-outside-init
def fuse_reshape_operations(self):
# Remove unnecessary Reshape operator. Consecutive Reshape operators with latter's input being "[-1]"
# i.e. flatten the input, the former Reshape operator is useless."""
import numpy as np
from onnxruntime.transformers.onnx_model import OnnxModel as TransformersOnnxModel
o_model = TransformersOnnxModel(self.model)
@ -215,7 +214,7 @@ class ModelOptimizer:
logger.debug(f"ModelOptimization: removed node {input_node_0.name}")
o_model.prune_graph()
self.model = o_model.model
self.model = o_model.model # pylint: disable=attribute-defined-outside-init
class OnnxModelOptimizer(Pass):


@ -125,6 +125,7 @@ class MoEExpertDistributionPatternMatcherA(MoEExpertDistributionPatternMatcher):
]
def __init__(self, world_size: int, input_filepath: str, debug=False):
# pylint: disable=useless-parent-delegation
super().__init__(world_size, input_filepath, debug)
@staticmethod


@ -87,16 +87,16 @@ def tune_onnx_model(model, data_root, config):
tuning_results = []
for tuning_combo in generate_tuning_combos(config):
tuning_item = ["provider", "execution_mode", "ort_opt_level", "io_bind"]
logger.info("Run tuning for: {}".format(list(zip(tuning_item, tuning_combo))))
logger.info("Run tuning for: %s", list(zip(tuning_item, tuning_combo)))
if not valid_config(tuning_combo, config):
continue
tuning_results.extend(threads_num_tuning(model, data_root, latency_metric, config, tuning_combo))
for tuning_result in tuning_results:
logger.debug("Tuning result: {}".format(tuning_result["latency_ms"]))
logger.debug("Tuning result: %s", tuning_result["latency_ms"])
best_result = parse_tuning_result(*tuning_results, pretuning_inference_result)
logger.info("Best result: {}".format(best_result))
logger.info("Best result: %s", best_result)
if best_result.get("test_name") != "pretuning":
optimized_model = copy.copy(model)
optimized_model.inference_settings = {
@ -152,8 +152,7 @@ def threads_num_tuning(model, data_root, latency_metric, config, tuning_combo):
test_params["session_options"]["intra_op_num_threads"] = intra
threads_num_binary_search(model, data_root, latency_metric, config, test_params, tuning_results)
except Exception:
logger.error("Optimization failed for tuning combo {}".format(tuning_combo), exc_info=True)
pass
logger.error("Optimization failed for tuning combo %s", tuning_combo, exc_info=True)
return tuning_results
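The perf-tuning hunk above also switches `.format()` calls inside logger calls to lazy %-style arguments, so the message is only interpolated if the record is actually emitted. Sketch (the combo value is illustrative):

```python
# Lazy %-style logging arguments; the tuning combo value is illustrative.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("perf_tuning")

tuning_combo = ("CPUExecutionProvider", 0, 99, False)

logger.info("Run tuning for: %s", tuning_combo)          # formatted only when emitted
logger.debug("Tuning result: %s", {"latency_ms": 1.0})   # skipped entirely at INFO level
```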


@ -6,6 +6,8 @@ import importlib as imp
from typing import Dict, List, Union
import onnx
# pylint: disable=wildcard-import
from onnxruntime_extensions.tools.pre_post_processing import * # noqa: F401, F403, RUF100
from onnxruntime_extensions.tools.pre_post_processing.utils import create_named_value
@ -178,7 +180,7 @@ def parse_step_params(model: onnx.ModelProto, step_config: Dict):
# Customized type definition
param_cls = get_customized_class(param_type)
params[param_name] = param_cls(**param_args)
elif param_type == "tuple" or param_type == "list":
elif param_type in ("tuple", "list"):
param_value = param_value.get("value")
# explicitly list or tuple type is specified


@ -26,6 +26,8 @@ from olive.strategy.search_parameter import Boolean, Categorical, Conditional, C
logger = logging.getLogger(__name__)
# pylint: disable=consider-using-with
# common config for both static and dynamic quantization
_onnx_quantization_config = {
"weight_type": PassConfigParam(
@ -220,6 +222,7 @@ class OnnxQuantization(Pass):
def _initialize(self):
super()._initialize()
# pylint: disable=attribute-defined-outside-init
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp")
@staticmethod


@ -23,6 +23,7 @@ from onnxruntime.quantization.quant_utils import QuantType
from olive.passes.onnx.vitis_ai.quant_utils import PowerOfTwoMethod, is_ort_version_below_1_16, quantize_data_pof2s
# pylint: skip-file
# ruff: noqa


@ -11,6 +11,7 @@ from onnxruntime import __version__ as OrtVersion
from onnxruntime.quantization.quant_utils import get_qmin_qmax_for_qType, quantize_nparray
from packaging import version
# pylint: skip-file
# ruff: noqa


@ -19,6 +19,7 @@ from olive.passes.onnx.vitis_ai.calibrate import PowerOfTwoMethod, create_calibr
from olive.passes.onnx.vitis_ai.quant_utils import get_exclude_nodes, is_ort_version_below_1_16
from olive.passes.onnx.vitis_ai.quantizer import VitisDPUQuantizer, VitisQDQQuantizer, VitisQOpQuantizer
# pylint: skip-file
# ruff: noqa


@ -51,6 +51,7 @@ from olive.passes.onnx.vitis_ai.refine import adjust_quantize_info
logger = logging.getLogger(__name__)
# pylint: skip-file
# ruff: noqa
@ -125,7 +126,6 @@ class VitisDPUQuantizer(QDQQuantizer):
self.is_activation_symmetric = True
def vitis_quantize_initializer(self, weight, bit_width=8, keep_float_weight=False):
# Find if this input is already quantized
if weight.name in self.quantized_value_map:
quantized_value = self.quantized_value_map[weight.name]
@ -162,7 +162,6 @@ class VitisDPUQuantizer(QDQQuantizer):
return q_weight_name, zp_name, scale_name
def quantize_model(self):
self.tensor_info = {}
model = self.model.model
annotate_output_name_list = get_annotate_output_name(model)


@ -12,6 +12,7 @@ refine_op_type = ["DequantizeLinear", "QuantizeLinear"]
postfix = "_Output"
logger = logging.getLogger(__name__)
# pylint: skip-file
# ruff: noqa
@ -235,7 +236,6 @@ class QuantPosManager(object):
shift_sigmoid = 14 + 'input pos' - ' output pos'
"""
for i, node in enumerate(self.model.model.graph.node):
if node.op_type not in ["Sigmoid"]:
continue
ipos_name = self.get_ipos_name(node)
@ -281,7 +281,6 @@ class QuantPosManager(object):
1. 0 <= shift_read <= 15
"""
for i, node in enumerate(self.model.model.graph.node):
if node.op_type not in ["Add"] or node.op_type not in ["Mul"]:
continue
ipos_layers = []
@ -339,7 +338,6 @@ class QuantPosManager(object):
1. -15 <= shift_write <= 15
"""
for i, node in enumerate(self.model.model.graph.node):
if node.op_type not in ["Add"] or node.op_type not in ["Mul"]:
continue
ipos_layers = []


@ -22,6 +22,9 @@ from olive.strategy.search_parameter import Boolean, Categorical, Conditional
logger = logging.getLogger(__name__)
# pylint: disable=consider-using-with, attribute-defined-outside-init
# common config for Vitis-AI quantization
vai_q_onnx_quantization_config = {
"data_dir": PassConfigParam(


@ -107,9 +107,9 @@ class OpenVINOQuantization(Pass):
metric = self._user_module_loader.load_object(config["metric_func"])
engine = IEEngine(config=config["engine_config"], data_loader=data_loader, metric=metric)
self.pipeline = create_pipeline(config["algorithms"], engine)
pipeline = create_pipeline(config["algorithms"], engine)
compressed_model = self.pipeline.run(model=model.load_model())
compressed_model = pipeline.run(model=model.load_model())
compress_model_weights(compressed_model)
compressed_model_paths = save_model(
model=compressed_model,


@ -88,7 +88,7 @@ class PassConfigBase(ConfigBase):
finally:
if field.required and isinstance(v, PassParamDefault):
raise ValueError(f"{field.name} is required and cannot be set to {v.value}")
return v # noqa: B012
return v # noqa: B012 # pylint: disable=lost-exception, return-in-finally
@validator("*", pre=True)
def _validate_search_parameter(cls, v):


@ -126,11 +126,11 @@ class BaseClusterEnvironment(ClusterEnvironment, ABC):
class AzureMLPerProcessCluster(BaseClusterEnvironment):
def _environment_variable_overrides(self, master_port: int = 6105) -> Dict[str, str]:
def _environment_variable_overrides(self, port: int = 6105) -> Dict[str, str]:
"""Set the MPI environment variables required for multinode distributed training.
Args:
master_port (int): Used to set MASTER_PORT environment variable if its not present.
port (int): Used to set MASTER_PORT environment variable if its not present.
"""
overrides = {}
@ -144,7 +144,7 @@ class AzureMLPerProcessCluster(BaseClusterEnvironment):
# Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
if "MASTER_PORT" not in os.environ:
overrides["MASTER_PORT"] = str(master_port)
overrides["MASTER_PORT"] = str(port)
else:
overrides["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
overrides["MASTER_PORT"] = "54965"


@ -154,7 +154,7 @@ class QatTrainer:
if type(child) in white_list:
if type(child) not in skip_list:
new = QuantizedModule(child)
new.qconfig = qconfig
new.qconfig = qconfig # pylint: disable=attribute-defined-outside-init
setattr(module, name, new)
def _recursive_hasattr(self, obj, attribs, state=True):


@ -34,8 +34,8 @@ def get_layers(model, model_type):
def get_layer_submodules(module, submodule_types=None, layer_name_filter=None, name=""):
submodule_types = submodule_types or [torch.nn.Conv2d, torch.nn.Linear, transformers.Conv1D]
"""Get the submodules of a module based on the submodule types."""
submodule_types = submodule_types or [torch.nn.Conv2d, torch.nn.Linear, transformers.Conv1D]
if type(module) in submodule_types:
if layer_name_filter and not any(s in name for s in layer_name_filter):
# skip this layer
@ -193,6 +193,7 @@ class SparseGPTModule:
self.H += batch_input.matmul(batch_input.t())
def prune(self, mode, sparsity=None, n=None, m=None, blocksize=128, percdamp=0.01):
# pylint: disable=not-callable
W = self.get_W()
H = self.H
del self.H
@ -204,7 +205,7 @@ class SparseGPTModule:
W[:, dead] = 0
# dampen the Hessian
assert percdamp >= 0 and percdamp <= 1
assert 0 <= percdamp <= 1
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.device)
H[diag, diag] += damp
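The same SparseGPT hunk tightens the damping-factor assertion into a chained comparison, which pylint prefers and which reads closer to the underlying math. Sketch:

```python
# Chained-comparison form of the damping-factor range check.
percdamp = 0.01

assert percdamp >= 0 and percdamp <= 1  # old form
assert 0 <= percdamp <= 1               # equivalent chained comparison
```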


@ -3,7 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
try:
import torch_tensorrt # noqa: F401
import torch_tensorrt # noqa: F401 # pylint: disable=unused-import
except ImportError:
raise ImportError("Please install torch_tensorrt with: pip install torch-tensorrt") from None
@ -13,8 +13,8 @@ from contextlib import redirect_stdout
import tensorrt as trt
import torch
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule, compile
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule, compile # pylint: disable=redefined-builtin
from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
class TRTLinearLayer(TRTModule):


@ -297,7 +297,7 @@ class AzureMLResource(ResourcePath):
def get_path(self) -> str:
raise NotImplementedError
def save_to_dir(
def _save_to_dir(
self,
dir_path: Union[Path, str],
ml_client,
@ -363,7 +363,7 @@ class AzureMLModel(AzureMLResource):
def save_to_dir(self, dir_path: Union[Path, str], name: str = None, overwrite: bool = False) -> str:
ml_client = self.config.azureml_client.create_client()
return super().save_to_dir(dir_path, ml_client, self.config.azureml_client, overwrite, name)
return self._save_to_dir(dir_path, ml_client, self.config.azureml_client, overwrite, name)
class AzureMLRegistryModel(AzureMLResource):
@ -392,7 +392,7 @@ class AzureMLRegistryModel(AzureMLResource):
azureml_client_config = self.config.azureml_client or AzureMLClientConfig()
ml_client = azureml_client_config.create_registry_client(self.config.registry_name)
return super().save_to_dir(dir_path, ml_client, azureml_client_config, overwrite, name)
return self._save_to_dir(dir_path, ml_client, azureml_client_config, overwrite, name)
def _datastore_url_validator(v, values, **kwargs):


@ -29,7 +29,7 @@ def dev():
logger.info("Done")
def eval(): # noqa: A001
def eval(): # noqa: A001 #pylint: disable=redefined-builtin
snpe_arch = get_snpe_target_arch(False)
if snpe_arch != "ARM64-Windows":
return


@ -42,7 +42,7 @@ class SNPEDataLoader(ABC):
return
if self.tmp_dir is None:
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_")
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_") # pylint: disable=consider-using-with
self.batch_dir = str(Path(self.tmp_dir.name) / "batch")
batch_input_list = input_list_utils.resolve_input_list(
@ -107,6 +107,7 @@ class SNPEDataLoader(ABC):
return self.num_batches
def __iter__(self):
# pylint: disable=attribute-defined-outside-init
self.n = 0
return self
@ -137,7 +138,7 @@ class SNPEProcessedDataLoader(SNPEDataLoader):
super().__init__(config, batch_size)
def load_data(self) -> Tuple[str, str, np.ndarray]:
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_")
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_") # pylint: disable=consider-using-with
input_list = input_list_utils.get_input_list(
self.config["data_dir"], self.config["input_list_file"], self.tmp_dir.name
)
@ -177,7 +178,7 @@ class SNPERandomDataLoader(SNPEDataLoader):
super().__init__(config, batch_size)
def load_data(self) -> Tuple[str, str, np.ndarray]:
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_")
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_") # pylint: disable=consider-using-with
# get data_dir
if self.config["data_dir"] is None:
@ -241,20 +242,20 @@ class SNPECommonDataLoader(SNPEDataLoader):
# get single data sample
input_data, _ = next(iter(self.config["dataloader"]))
# source input names
for input_name in input_specs:
for input_name, input_spec in input_specs.items():
if input_name in input_data:
source_name = input_name
elif input_name.strip(":0") in input_data:
source_name = input_name.strip(":0")
else:
raise ValueError(f"Input name {input_name} not found in dataset")
input_specs[input_name]["source_name"] = source_name
input_spec["source_name"] = source_name
# source input_shapes and permutations
for input_name, input_spec in input_specs.items():
for input_spec in input_specs.values():
# get source shape
source_shape = list(input_data[input_spec["source_name"]].shape)
input_specs[input_name]["source_shape"] = source_shape
input_spec["source_shape"] = source_shape
# get permutation from source shape to target shape
target_shape = input_spec["target_shape"]
@ -283,10 +284,10 @@ class SNPECommonDataLoader(SNPEDataLoader):
f" shape {target_shape}"
)
input_specs[input_name]["permutation"] = permutation
input_spec["permutation"] = permutation
logger.debug(f"Input specs: {input_specs}")
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_")
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_") # pylint: disable=consider-using-with
data_dir = Path(self.tmp_dir.name) / "data"
data_dir.mkdir() # create data dir
@ -330,7 +331,7 @@ class SNPECommonDataLoader(SNPEDataLoader):
input_list_file = input_list_utils.create_input_list(
data_dir=str(data_dir),
input_names=list(input_specs.keys()),
input_dirs=[input_specs[input_name]["source_name"] for input_name in input_specs],
input_dirs=[input_spec["source_name"] for input_spec in input_specs.values()],
add_input_names=len(input_specs) > 1,
add_output_names=len(self.config["io_config"]["output_names"]) > 1,
output_names=self.config["io_config"]["output_names"],


@ -38,7 +38,7 @@ def get_dlc_io_config(dlc_path: str, input_names: List[str], output_names: List[
input_names: list of input names of source model.
output_names: list of output names of source model.
"""
tmp_csv = tempfile.NamedTemporaryFile(suffix=".csv")
tmp_csv = tempfile.NamedTemporaryFile(suffix=".csv") # pylint: disable=consider-using-with
dlc_info = get_dlc_info(dlc_path, csv_path=tmp_csv.name)
# add the :0 suffix to the input/output names if present in the DLC
@ -172,7 +172,7 @@ def to_dlc(model_file: str, model_framework: str, config: dict, output_file: str
# check if conversion succeeded
if "Conversion completed successfully" not in stderr:
raise Exception(stderr)
raise Exception(stderr) # pylint: disable=broad-exception-raised
def quantize_dlc(dlc_path: str, input_list: str, config: dict, output_file: str):
@ -200,7 +200,7 @@ def quantize_dlc(dlc_path: str, input_list: str, config: dict, output_file: str)
# check if quantization succeeded
if not ("Writing quantized model" in stderr or "Saved quantized dlc" in stderr):
raise Exception(stderr)
raise Exception(stderr) # pylint: disable=broad-exception-raised
def dlc_to_onnx(


@ -158,7 +158,7 @@ def snpe_net_run(
android_persist_ws: Whether to persist the workspace on android.
android_initialized: Whether the inference session has already been initialized on android using init_snpe_net_adb.
"""
tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_")
tmp_dir = tempfile.TemporaryDirectory(prefix="olive_tmp_") # pylint: disable=consider-using-with
tmp_dir_path = Path(tmp_dir.name)
# Create the snpe-net-run command


@ -130,8 +130,9 @@ class Conditional(SearchParameter):
def get_support(self, parent_values: Dict[str, Any]) -> Union[List[str], List[int], List[float], List[bool]]:
"""Get the support for the search parameter for a given parent value."""
# pylint: disable=arguments-differ
assert parent_values.keys() == set(self.parents), "parent values keys do not match the parents"
parent_values = tuple([parent_values[parent] for parent in self.parents])
parent_values = tuple(parent_values[parent] for parent in self.parents)
return self.support.get(parent_values, self.default).get_support()
def condition(self, parent_values: Dict[str, Any]) -> SearchParameter:
@ -163,7 +164,7 @@ class Conditional(SearchParameter):
new_conditional = Conditional(new_parents, new_support, self.default)
# condition the new conditional if there are more parents to condition, else return the new conditional
del parent_values[parent]
del parent_values[parent] # pylint: disable=undefined-loop-variable
if len(parent_values) == 0:
return new_conditional
return new_conditional.condition(parent_values)
@ -285,7 +286,7 @@ def json_to_search_parameter(json: Dict[str, Any]) -> SearchParameter:
search_parameter_type = json["type"]
if search_parameter_type == "Categorical":
return Categorical(json["support"])
if search_parameter_type == "Conditional" or search_parameter_type == "ConditionalDefault":
if search_parameter_type in ("Conditional", "ConditionalDefault"):
def stop_condition(x):
return isinstance(x, dict) and x.get("olive_parameter_type") == "SearchParameter"


@ -49,7 +49,7 @@ class SearchResults:
def check_goals(self, result: MetricResult) -> bool:
"""Check if the result satisfies the constraints."""
# if goals are not set, return True always
if self.goals == {}:
if not self.goals:
return True
for obj, goal in self.goals.items():


@ -18,6 +18,8 @@ logger = logging.getLogger(__name__)
_VALID_EXECUTION_ORDERS = ["joint", "pass-by-pass"]
# pylint: disable=attribute-defined-outside-init
class SearchStrategyConfig(ConfigBase):
execution_order: str


@ -76,7 +76,7 @@ def main(raw_args=None):
# original directory for model path is read only, so we need to copy the model to a temp directory
input_model_path = input_model_config["config"].get("model_path")
if input_model_path is not None:
tmp_dir = tempfile.TemporaryDirectory()
tmp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
old_path = Path(input_model_path).resolve()
new_path = Path(tmp_dir.name).resolve() / old_path.name
if old_path.is_file():


@ -25,6 +25,8 @@ from olive.systems.olive_system import OliveSystem
logger = logging.getLogger(__name__)
# pylint: disable=c-extension-no-member
class DockerSystem(OliveSystem):
system_type = SystemType.Docker
@ -98,7 +100,7 @@ class DockerSystem(OliveSystem):
def run_pass(
self,
the_pass: Pass,
model: ModelConfig,
model_config: ModelConfig,
data_root: str,
output_model_path: str,
point: Optional[Dict[str, Any]] = None,
@ -138,7 +140,7 @@ class DockerSystem(OliveSystem):
volumes_list.append(eval_file_mount_str)
if self.is_dev:
dev_mount_path, dev_mount_str = docker_utils.create_dev_mount(tempdir, container_root_path)
_, dev_mount_str = docker_utils.create_dev_mount(tempdir, container_root_path)
volumes_list.append(dev_mount_str)
model_config_copy = copy.deepcopy(model_config)


@ -13,6 +13,7 @@ import numpy as np
ort_inference_utils_parent = Path(__file__).resolve().parent.parent.parent / "common"
sys.path.append(str(ort_inference_utils_parent))
# pylint: disable=wrong-import-position
from ort_inference import get_ort_inference_session # noqa: E402


@ -10,6 +10,7 @@ from pathlib import Path
ort_inference_utils_parent = Path(__file__).resolve().parent.parent.parent / "common"
sys.path.append(str(ort_inference_utils_parent))
# pylint: disable=wrong-import-position
from ort_inference import get_ort_inference_session # noqa: E402


@@ -70,7 +70,7 @@ class PythonEnvironmentSystem(OliveSystem):
os.makedirs(temp_dir)
self.environ["TMPDIR"] = temp_dir
else:
self.environ["TMPDIR"] = tempfile.TemporaryDirectory().name
self.environ["TMPDIR"] = tempfile.TemporaryDirectory().name # pylint: disable=consider-using-with
# available eps. This will be populated the first time self.get_supported_execution_providers() is called.
# used for caching the available eps


@@ -69,6 +69,7 @@ def create_new_system_with_cache(origin_system, accelerator):
def create_new_system(origin_system, accelerator):
# pylint: disable=consider-using-with
provider_dockerfile_mapping = {
"CPUExecutionProvider": "Dockerfile.cpu",
"CUDAExecutionProvider": "Dockerfile.gpu",


@@ -59,7 +59,7 @@ def convertquantize(
snpe_conversion = create_pass_from_dict(
SNPEConversion, {**config["io_config"], **config["convert_options"]}, disable_search=True
)
snpe_model = snpe_conversion.run(model, snpe_model_file)
snpe_model = snpe_conversion.run(model, None, snpe_model_file)
assert Path(snpe_model.model_path).is_file()
with (models_dir / f"{name}.dlc_io_config.json").open("w") as f:
json.dump(snpe_model.io_config, f)
@@ -68,14 +68,16 @@ def convertquantize(
# SNPE Quantized model
logger.info("Quantizing SNPE model...")
snpe_quantized_model_file = str(models_dir / f"{name}.quant.dlc")
dataloader_func = lambda data_dir: SNPEProcessedDataLoader(data_dir, input_list_file=input_list_file) # noqa: E731
def dataloader_func(data_dir):
return SNPEProcessedDataLoader(data_dir, input_list_file=input_list_file)
snpe_quantization = create_pass_from_dict(
SNPEQuantization,
{"data_dir": str(data_dir), "dataloader_func": dataloader_func, **config["quantize_options"]},
disable_search=True,
)
snpe_quantized_model = snpe_quantization.run(snpe_model, snpe_quantized_model_file)
snpe_quantized_model = snpe_quantization.run(snpe_model, None, snpe_quantized_model_file)
assert Path(snpe_quantized_model.model_path).is_file()
with (models_dir / f"{name}.quant.dlc_io_config.json").open("w") as f:
json.dump(snpe_quantized_model.io_config, f)
@@ -94,5 +96,5 @@ def convertquantize(
{"target_device": config["quantize_options"].get("target_device", SNPEDevice.CPU)},
disable_search=True,
)
snpe_quantized_onnx_model = snpe_to_onnx_conversion.run(snpe_quantized_model, snpe_quantized_onnx_model_file)
snpe_quantized_onnx_model = snpe_to_onnx_conversion.run(snpe_quantized_model, None, snpe_quantized_onnx_model_file)
assert Path(snpe_quantized_onnx_model.model_path).is_file()


@@ -39,16 +39,39 @@ good-names = [
[tool.pylint.messages_control]
disable = [
# TODO(myguo): consider removing these from the blacklist
"broad-exception-caught",
"cyclic-import", # Disable cyclic-import because it is pylint bug
"consider-using-f-string",
"consider-using-from-import",
"format",
"expression-not-assigned",
"line-too-long",
"import-error",
"import-outside-toplevel",
"invalid-name",
"logging-format-interpolation",
"logging-fstring-interpolation",
"no-else-continue",
"no-else-raise",
"no-else-return",
"no-name-in-module",
"no-member",
"too-many-arguments",
"too-many-locals",
"no-self-argument",
"too-few-public-methods",
"too-many-arguments",
"too-many-branches",
"too-many-function-args",
"too-many-instance-attributes",
"too-many-locals",
"too-many-nested-blocks",
"too-many-public-methods",
"too-many-return-statements",
"too-many-statements",
"missing-docstring",
"fixme"
"fixme",
"unspecified-encoding",
"unused-argument"
]
[tool.ruff]


@@ -1,6 +1,10 @@
-r requirements.txt
black
editorconfig-checker
flake8
isort
lintrunner
lintrunner-adapters
pre-commit
pylint
ruff==0.1.0


@@ -1,9 +0,0 @@
# BLACK-ISORT
black==23.7.0
isort==5.12.0
# This file is auto updated by dependabot
lintrunner-adapters>=0.8.0
# PYLINT
pylint==2.17.2
# RUFF, RUFF-FIX
ruff==0.0.291


@@ -64,7 +64,7 @@ long_description = (
" taking a set of constraints such as accuracy and latency into consideration."
)
description = long_description.split(".")[0] + "."
description = long_description.split(".", maxsplit=1)[0] + "."
setup(
name="olive-ai",


@@ -10,6 +10,8 @@ import pytest
from olive.resource_path import ResourceType, create_resource_path
# pylint: disable=attribute-defined-outside-init, consider-using-with
class TestAMLResourcePath:
@pytest.fixture(autouse=True)


@@ -13,6 +13,8 @@ from olive.azureml.azureml_client import AzureMLClientConfig
from olive.evaluator.metric import AccuracySubType, LatencySubType, Metric, MetricType
from olive.systems.azureml import AzureMLDockerConfig, AzureMLSystem
# pylint: disable=redefined-outer-name
def get_directories():
current_dir = Path(__file__).resolve().parent


@@ -13,6 +13,8 @@ from torchvision.transforms import ToTensor
from olive.evaluator.metric import AccuracySubType, LatencySubType, Metric, MetricType
from olive.systems.docker import DockerSystem, LocalDockerConfig
# pylint: disable=redefined-outer-name
def get_directories():
current_dir = Path(__file__).resolve().parent


@@ -25,6 +25,8 @@ from olive.hardware import DEFAULT_CPU_ACCELERATOR
from olive.model import ModelConfig
from olive.systems.local import LocalSystem
# pylint: disable=redefined-builtin
class TestLocalEvaluation:
@pytest.fixture(scope="class", autouse=True)


@@ -13,6 +13,8 @@ from torchvision.transforms import ToTensor
from olive.evaluator.metric import AccuracySubType, LatencySubType, Metric, MetricType
# pylint: disable=redefined-outer-name
def get_directories():
current_dir = Path(__file__).resolve().parent


@@ -6,6 +6,8 @@ import os
from azure.storage.blob import BlobClient
# pylint: disable=broad-exception-raised
def get_olive_workspace_config():
subscription_id = os.environ.get("WORKSPACE_SUBSCRIPTION_ID")


@@ -15,6 +15,8 @@ from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
from olive.model import ModelConfig
from olive.passes.onnx import OrtPerfTuning
# pylint: disable=attribute-defined-outside-init, consider-using-with
@pytest.mark.skipif(platform.system() == "Windows", reason="Skip test on Windows.")
class TestOliveAzureMLSystem:


@@ -14,6 +14,8 @@ from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
from olive.model import ModelConfig
from olive.passes.onnx import OrtPerfTuning
# pylint: disable=attribute-defined-outside-init, consider-using-with
@pytest.mark.skipif(platform.system() == "Windows", reason="Docker target does not support windows")
class TestOliveManagedDockerSystem:


@@ -16,6 +16,8 @@ from olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
from olive.passes.onnx import OrtPerfTuning
from olive.systems.python_environment import PythonEnvironmentSystem
# pylint: disable=attribute-defined-outside-init, consider-using-with
class TestOliveManagedPythonEnvironmentSystem:
@pytest.fixture(autouse=True)


@@ -10,6 +10,8 @@ from torchvision.transforms import ToTensor
from olive.evaluator.metric import LatencySubType, Metric, MetricType
# pylint: disable=redefined-outer-name
def get_directories():
current_dir = Path(__file__).resolve().parent


@@ -6,6 +6,8 @@ import pytest
from olive.common.utils import retry_func
# pylint: disable=global-variable-undefined, used-before-assignment
def fail_with_key_error():
global num_tries


@@ -10,6 +10,8 @@ import pytest
from olive.data.config import DataConfig
from olive.data.registry import Registry
# pylint: disable=attribute-defined-outside-init
class TestDataConfig:
@pytest.fixture(autouse=True)


@@ -16,6 +16,8 @@ import pytest
from olive.data.config import DataConfig
from olive.data.container.data_container import DataContainer
# pylint: disable=attribute-defined-outside-init
class TestDataConfig:
@pytest.fixture(autouse=True)


@@ -21,6 +21,8 @@ from olive.evaluator.olive_evaluator import OliveEvaluatorConfig
from olive.hardware import DEFAULT_CPU_ACCELERATOR
from olive.passes.onnx.conversion import OnnxConversion
# pylint: disable=consider-using-with
@patch("onnx.external_data_helper.sys.getsizeof")
@pytest.mark.parametrize(


@@ -25,6 +25,8 @@ from olive.passes.onnx import OnnxConversion, OnnxDynamicQuantization, OnnxStati
from olive.systems.common import SystemType
from olive.systems.local import LocalSystem
# pylint: disable=consider-using-with, protected-access
# Please note your test case could still "pass" even if it throws an exception that should fail it.
# Please check the log messages to make sure your test case actually passes.


@@ -10,6 +10,8 @@ from olive.engine.footprint import Footprint
from olive.engine.footprint import Footprint
# pylint: disable=attribute-defined-outside-init
class TestFootprint:
@pytest.fixture(autouse=True)


@@ -13,6 +13,8 @@ from olive.evaluator.metric_backend import HuggingfaceMetrics
from olive.hardware import DEFAULT_CPU_ACCELERATOR
from olive.systems.local import LocalSystem
# pylint: disable=attribute-defined-outside-init, redefined-outer-name
class TestMetricBackend:
@pytest.fixture(autouse=True)
@@ -59,8 +61,6 @@ class TestMetricBackend:
HF_ACCURACY_TEST_CASE,
)
def test_evaluate_backend(self, model_config, metric, expected_res):
from olive.evaluator.metric_backend import HuggingfaceMetrics
with patch.object(HuggingfaceMetrics, "measure_sub_metric") as mock_measure:
mock_measure.return_value = SubMetricResult(value=expected_res, higher_is_better=True, priority=-1)
system = LocalSystem()


@@ -24,7 +24,6 @@ from olive.systems.local import LocalSystem
class TestOliveEvaluator:
ACCURACY_TEST_CASE: ClassVar[list] = [
(
PyTorchEvaluator(),
@@ -184,7 +183,7 @@ class TestDistributedOnnxEvaluator:
from olive.model import DistributedOnnxModel
filepaths = ["examples/switch/model_4n_2l_8e_00.onnx", "examples/switch/model_4n_2l_8e_01.onnx"]
model = DistributedOnnxModel(filepaths, name="model_4n_2l_8e")
model = DistributedOnnxModel(filepaths)
user_config = {
"user_script": "examples/switch/user_script.py",


@@ -18,6 +18,8 @@ from azureml.evaluate import mlflow as aml_mlflow
from olive.model import PyTorchModel
# pylint: disable=attribute-defined-outside-init, consider-using-with
class TestPyTorchMLflowModel(unittest.TestCase):
def setup(self):


@@ -10,6 +10,8 @@ from pydantic import ValidationError
from olive.hardware import DEFAULT_CPU_ACCELERATOR
from olive.passes.onnx import OrtPerfTuning
# pylint: disable=consider-using-with
class TestUserScriptConfig:
def test_no_config(self):


@@ -149,7 +149,6 @@ def get_superresolution_model():
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
return self.pixel_shuffle(self.conv4(x))
return x
def _initialize_weights(self):
init.orthogonal_(self.conv1.weight, init.calculate_gain("relu"))


@@ -12,6 +12,8 @@ from olive.hardware import DEFAULT_CPU_ACCELERATOR, DEFAULT_GPU_CUDA_ACCELERATOR
from olive.passes.onnx import OrtTransformersOptimization
from olive.passes.onnx.common import get_external_data_config
# pylint: disable=redefined-outer-name, abstract-method, protected-access
def test_fusion_options():
config = {"model_type": "bart", "optimization_options": {"use_multi_head_attention": True}}

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше