* add_pylint_to_workflow

* fix-pylint

* fix_pylinterror

* fix-issue
SunsetWolf 2022-01-26 19:27:24 +08:00 committed by GitHub
Parent 635632e4ed
Commit 144e1e2459
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
103 changed files with 318 additions and 387 deletions

32
.github/workflows/test.yml vendored
View file

@ -33,7 +33,37 @@ jobs:
- name: Install Qlib with pip
run: |
pip install numpy==1.19.5 ruamel.yaml
pip install pyqlib --ignore-installed
pip install pyqlib --ignore-installed
# Check Qlib with pylint
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
# C0103: invalid-name
# C0209: consider-using-f-string
# R0402: consider-using-from-import
# R1705: no-else-return
# R1710: inconsistent-return-statements
# R1725: super-with-arguments
# R1735: use-dict-literal
# W0102: dangerous-default-value
# W0212: protected-access
# W0221: arguments-differ
# W0223: abstract-method
# W0231: super-init-not-called
# W0237: arguments-renamed
# W0612: unused-variable
# W0621: redefined-outer-name
# W0622: redefined-builtin
# FIXME: specify exception type
# W0703: broad-except
# W1309: f-string-without-interpolation
# E1102: not-callable
# E1136: unsubscriptable-object
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
- name: Check Qlib with pylint
run: |
pip install --upgrade pip
pip install pylint
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
- name: Test data downloads
run: |
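For context on the codes disabled above: each ID maps to a pylint message, and several mark patterns that are easy to reproduce in isolation. A minimal, hypothetical sketch (not from the Qlib codebase) of W0102 dangerous-default-value and its conventional fix:

    # W0102: a mutable default is created once and shared across calls
    def collect_bad(item, acc=[]):
        acc.append(item)
        return acc

    print(collect_bad(1))   # [1]
    print(collect_bad(2))   # [1, 2]  -- the same list leaks between calls

    # usual fix: default to None and create the list inside the function
    def collect_ok(item, acc=None):
        if acc is None:
            acc = []
        acc.append(item)
        return acc

The inline "# pylint: disable=<code>" comments used throughout the rest of this commit silence one occurrence at a time, whereas the --disable list above turns a check off for the whole run.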

View file

@ -30,8 +30,8 @@ def init(default_conf="client", **kwargs):
When using the recorder, skip_if_reg can set to True to avoid loss of recorder.
"""
from .config import C
from .data.cache import H
from .config import C # pylint: disable=C0415
from .data.cache import H # pylint: disable=C0415
# FIXME: this logger ignored the level in config
logger = get_module_logger("Initialization", level=logging.INFO)
@ -85,7 +85,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
# If the provider uri looks like this 172.23.233.89//data/csdesign'
# It will be a nfs path. The client provider will be used
if not auto_mount:
if not auto_mount: # pylint: disable=R1702
if not Path(mount_path).exists():
raise FileNotFoundError(
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
@ -139,8 +139,10 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
if not _is_mount:
try:
Path(mount_path).mkdir(parents=True, exist_ok=True)
except Exception:
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
except Exception as e:
raise OSError(
f"Failed to create directory {mount_path}, please create {mount_path} manually!"
) from e
# check nfs-common
command_res = os.popen("dpkg -l | grep nfs-common")
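The "raise ... from e" edit here recurs in later files: re-raising with an explicit cause keeps the original traceback attached and satisfies pylint's raise-missing-from check. A small, self-contained illustration (hypothetical file name and message):

    def read_config(path):
        try:
            with open(path, encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError as e:
            # "from e" sets __cause__, so the traceback shows the original
            # FileNotFoundError as the direct cause of the RuntimeError
            raise RuntimeError(f"config {path} is missing") from e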

View file

@ -171,8 +171,8 @@ def get_strategy_executor(
# NOTE:
# - for avoiding recursive import
# - typing annotations is not reliable
from ..strategy.base import BaseStrategy
from .executor import BaseExecutor
from ..strategy.base import BaseStrategy # pylint: disable=C0415
from .executor import BaseExecutor # pylint: disable=C0415
trade_account = create_account_instance(
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type

View file

@ -2,11 +2,11 @@
# Licensed under the MIT License.
from __future__ import annotations
import copy
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple
from qlib.utils import init_instance_by_config
import pandas as pd
from .position import BasePosition, InfPosition, Position
from .position import BasePosition
from .report import PortfolioMetrics, Indicator
from .decision import BaseTradeDecision, Order
from .exchange import Exchange

View file

@ -7,19 +7,18 @@ from qlib.data.data import Cal
from qlib.utils.time import concat_date_time, epsilon_change
from qlib.log import get_module_logger
from typing import ClassVar, Optional, Union, List, Tuple
# try to fix circular imports when enabling type hints
from typing import Callable, TYPE_CHECKING
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from qlib.strategy.base import BaseStrategy
from qlib.backtest.exchange import Exchange
from qlib.backtest.utils import TradeCalendarManager
import warnings
import numpy as np
import pandas as pd
import numpy as np
from dataclasses import dataclass, field
from typing import ClassVar, Optional, Union, List, Set, Tuple
from dataclasses import dataclass
class OrderDir(IntEnum):
@ -418,7 +417,7 @@ class BaseTradeDecision:
return kwargs["default_value"]
else:
# Default to get full index
raise NotImplementedError(f"The decision didn't provide an index range")
raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError
# clip index
if getattr(self, "total_step", None) is not None:
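The import cleanup in this file keeps the TYPE_CHECKING guard: names needed only for type hints are imported under it, so the circular dependency between the backtest and strategy modules never materialises at runtime. A minimal sketch of the pattern (hypothetical function, real qlib module path):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated only by static type checkers, never at runtime,
        # so importing this module cannot create an import cycle
        from qlib.strategy.base import BaseStrategy

    def run(strategy: BaseStrategy) -> None:
        ...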

View file

@ -3,13 +3,13 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import List, Tuple, Union
if TYPE_CHECKING:
from .account import Account
from qlib.backtest.position import BasePosition, Position
import random
from typing import List, Tuple, Union
import numpy as np
import pandas as pd
@ -18,7 +18,7 @@ from ..config import C
from ..constant import REG_CN
from ..log import get_module_logger
from .decision import Order, OrderDir, OrderHelper
from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote
from .high_performance_ds import BaseQuote, NumpyQuote
class Exchange:

View file

@ -1,22 +1,18 @@
from abc import abstractclassmethod, abstractmethod
from abc import abstractmethod
import copy
from qlib.backtest.position import BasePosition
from qlib.log import get_module_logger
from types import GeneratorType
from qlib.backtest.account import Account
import warnings
import pandas as pd
from typing import List, Tuple, Union
from collections import defaultdict
from qlib.backtest.report import Indicator
from .decision import EmptyTradeDecision, Order, BaseTradeDecision
from .decision import Order, BaseTradeDecision
from .exchange import Exchange
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
from ..utils import init_instance_by_config
from ..utils.time import Freq
from ..strategy.base import BaseStrategy
@ -193,7 +189,8 @@ class BaseExecutor:
pass
return return_value.get("execute_result")
@abstractclassmethod
@classmethod
@abstractmethod
def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
"""
Please refer to the doc of collect_data
@ -453,7 +450,6 @@ class NestedExecutor(BaseExecutor):
inner_exe_res :
the execution result of inner task
"""
pass
def get_all_executors(self):
"""get all executors, including self and inner_executor.get_all_executors()"""

View file

@ -2,8 +2,6 @@
# Licensed under the MIT License.
import copy
import pathlib
from typing import Dict, List, Union
import pandas as pd
@ -538,7 +536,7 @@ class InfPosition(BasePosition):
def get_stock_amount_dict(self) -> Dict:
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
def get_stock_weight_dict(self, only_stock: bool) -> Dict:
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
def add_count_all(self, bar):

View file

@ -10,11 +10,8 @@ import numpy as np
import pandas as pd
from qlib.backtest.exchange import Exchange
from .decision import IdxTradeRange
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
from qlib.backtest.utils import TradeCalendarManager
from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
from ..data import D
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
from ..tests.config import CSI300_BENCH
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
import qlib.utils.index_data as idd

View file

@ -388,13 +388,11 @@ class QlibConfig(Config):
default_conf : str
the default config template chosen by user: "server", "client"
"""
from .utils import set_log_with_config, get_module_logger, can_use_cache
from .utils import set_log_with_config, get_module_logger, can_use_cache # pylint: disable=C0415
self.reset()
_logging_config = self.logging_config
if "logging_config" in kwargs:
_logging_config = kwargs["logging_config"]
_logging_config = kwargs.get("logging_config", self.logging_config)
# set global config
if _logging_config:
@ -433,11 +431,11 @@ class QlibConfig(Config):
)
def register(self):
from .utils import init_instance_by_config
from .data.ops import register_all_ops
from .data.data import register_all_wrappers
from .workflow import R, QlibRecorder
from .workflow.utils import experiment_exit_handler
from .utils import init_instance_by_config # pylint: disable=C0415
from .data.ops import register_all_ops # pylint: disable=C0415
from .data.data import register_all_wrappers # pylint: disable=C0415
from .workflow import R, QlibRecorder # pylint: disable=C0415
from .workflow.utils import experiment_exit_handler # pylint: disable=C0415
register_all_ops(self)
register_all_wrappers(self)
@ -454,7 +452,7 @@ class QlibConfig(Config):
self._registered = True
def reset_qlib_version(self):
import qlib
import qlib # pylint: disable=C0415
reset_version = self.get("qlib_reset_version", None)
if reset_version is not None:
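The kwargs change above is a pure simplification: dict.get returns its second argument when the key is absent, which replaces the explicit membership test with one line. Illustration with made-up values:

    defaults = {"logging_config": {"level": "INFO"}}

    kwargs = {}
    cfg = kwargs.get("logging_config", defaults["logging_config"])
    print(cfg)   # {'level': 'INFO'}  -- falls back to the default

    kwargs = {"logging_config": {"level": "DEBUG"}}
    cfg = kwargs.get("logging_config", defaults["logging_config"])
    print(cfg)   # {'level': 'DEBUG'} -- an explicit value wins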

View file

@ -7,8 +7,7 @@ import warnings
import numpy as np
import pandas as pd
from qlib.utils import init_instance_by_config
from qlib.data.dataset import DatasetH, DataHandler
from qlib.data.dataset import DatasetH
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -16,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
def _to_tensor(x):
if not isinstance(x, torch.Tensor):
return torch.tensor(x, dtype=torch.float, device=device)
return torch.tensor(x, dtype=torch.float, device=device) # pylint: disable=E1101
return x

View file

@ -5,9 +5,7 @@ from ...data.dataset.handler import DataHandlerLP
from ...data.dataset.processor import Processor
from ...utils import get_callable_kwargs
from ...data.dataset import processor as processor_module
from ...log import TimeInspector
from inspect import getfullargspec
import copy
def check_transform_proc(proc_l, fit_start_time, fit_end_time):

View file

@ -1,9 +1,6 @@
import numpy as np
import pandas as pd
import copy
from ...log import TimeInspector
from ...utils.serial import Serializable
from ...data.dataset.processor import Processor, get_group_columns

View file

@ -5,12 +5,10 @@
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, pearsonr
from ..data import D
from collections import OrderedDict
@ -243,4 +241,4 @@ def get_rank_ic(a, b):
def get_normal_ic(a, b):
return pearsonr(a, b).correlation
return pearsonr(a, b)[0]

View file

@ -1,24 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from copy import deepcopy
from qlib.data.dataset.utils import init_task_handler
from qlib.utils.data import deepcopy_basic_type
from qlib.contrib.torch import data_to_tensor
from qlib.workflow.task.utils import TimeAdjuster
from qlib.model.meta.task import MetaTask
from typing import Dict, List, Union, Text, Tuple
from qlib.data.dataset.handler import DataHandler
from qlib.log import get_module_logger
from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from joblib import Parallel, delayed
from qlib.model.meta.dataset import MetaTaskDataset
from qlib.model.trainer import task_train, TrainerR
from qlib.data.dataset import DatasetH
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
from copy import deepcopy
from joblib import Parallel, delayed # pylint: disable=E0401
from typing import Dict, List, Union, Text, Tuple
from qlib.data.dataset.utils import init_task_handler
from qlib.data.dataset import DatasetH
from qlib.contrib.torch import data_to_tensor
from qlib.model.meta.task import MetaTask
from qlib.model.meta.dataset import MetaTaskDataset
from qlib.model.trainer import TrainerR
from qlib.log import get_module_logger
from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
from qlib.utils.data import deepcopy_basic_type
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster
from tqdm.auto import tqdm
class InternalData:

View file

@ -1,28 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.log import get_module_logger
import pandas as pd
import numpy as np
from qlib.model.meta.task import MetaTask
import torch
from torch import nn
from torch import optim
from tqdm.auto import tqdm
import collections
import copy
from typing import Union, List, Tuple, Dict
from typing import Union, List
from ....data.dataset.weight import Reweighter
from ....model.meta.dataset import MetaTaskDataset
from ....model.meta.model import MetaModel, MetaTaskModel
from ....model.meta.model import MetaTaskModel
from ....workflow import R
from .utils import ICLoss
from .dataset import MetaDatasetDS
from qlib.contrib.meta.data_selection.net import PredNet
from qlib.data.dataset.weight import Reweighter
from qlib.log import get_module_logger
from qlib.data.dataset.weight import Reweighter
from qlib.model.meta.task import MetaTask
from qlib.contrib.meta.data_selection.net import PredNet
logger = get_module_logger("data selection")

View file

@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import numpy as np
import torch
from torch import nn

View file

@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import numpy as np
import torch
from torch import nn
from qlib.contrib.torch import data_to_tensor
class ICLoss(nn.Module):

View file

@ -101,7 +101,7 @@ class LGBModel(ModelFT, LightGBMFInt):
verbose level
"""
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset, reweighter)
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
if dtrain.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")
self.model = lgb.train(

View file

@ -58,7 +58,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
"""
Test the signal in high frequency test set
"""
if self.model == None:
if self.model is None:
raise ValueError("Model hasn't been trained yet")
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
df_test.dropna(inplace=True)
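The "== None" to "is None" edits in this commit follow pylint's singleton-comparison advice: None is a singleton, so an identity check is unambiguous, while "==" dispatches to __eq__ and can be overridden. A contrived but runnable illustration:

    class AlwaysEqual:
        def __eq__(self, other):
            return True        # pathological __eq__, but legal Python

    x = AlwaysEqual()
    print(x == None)   # True  -- equality can be hijacked
    print(x is None)   # False -- identity against the None singleton cannot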

View file

@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation.
import os
from pdb import set_trace
from torch.utils.data import Dataset, DataLoader
import copy
from typing import Text, Union
import math
import numpy as np
import pandas as pd
import torch
@ -182,11 +180,11 @@ class ADARNN(Model):
continue
total_loss = torch.zeros(1).cuda()
for i in range(len(index)):
feature_s = list_feat[index[i][0]]
feature_t = list_feat[index[i][1]]
label_reg_s = list_label[index[i][0]]
label_reg_t = list_label[index[i][1]]
for i, n in enumerate(index):
feature_s = list_feat[n[0]]
feature_t = list_feat[n[1]]
label_reg_s = list_label[n[0]]
label_reg_t = list_label[n[1]]
feature_all = torch.cat((feature_s, feature_t), 0)
if epoch < self.pre_epoch:
@ -410,7 +408,7 @@ class AdaRNN(nn.Module):
in_size = hidden
self.features = nn.Sequential(*features)
if use_bottleneck == True: # finance
if use_bottleneck is True: # finance
self.bottleneck = nn.Sequential(
nn.Linear(n_hiddens[-1], bottleneck_width),
nn.Linear(bottleneck_width, bottleneck_width),
@ -449,7 +447,7 @@ class AdaRNN(nn.Module):
def forward_pre_train(self, x, len_win=0):
out = self.gru_features(x)
fea = out[0] # [2N,L,H]
if self.use_bottleneck == True:
if self.use_bottleneck is True:
fea_bottleneck = self.bottleneck(fea[:, -1, :])
fc_out = self.fc(fea_bottleneck).squeeze()
else:
@ -458,8 +456,8 @@ class AdaRNN(nn.Module):
out_list_all, out_weight_list = out[1], out[2]
out_list_s, out_list_t = self.get_features(out_list_all)
loss_transfer = torch.zeros((1,)).cuda()
for i in range(len(out_list_s)):
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
for i, n in enumerate(out_list_s):
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
h_start = 0
for j in range(h_start, self.len_seq, 1):
i_start = j - len_win if j - len_win >= 0 else 0
@ -471,7 +469,7 @@ class AdaRNN(nn.Module):
else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
)
loss_transfer = loss_transfer + weight * criterion_transder.compute(
out_list_s[i][:, j, :], out_list_t[i][:, k, :]
n[:, j, :], out_list_t[i][:, k, :]
)
return fc_out, loss_transfer, out_weight_list
@ -484,7 +482,7 @@ class AdaRNN(nn.Module):
out, _ = self.features[i](x_input.float())
x_input = out
out_lis.append(out)
if self.model_type == "AdaRNN" and predict == False:
if self.model_type == "AdaRNN" and predict is False:
out_gate = self.process_gate_weight(x_input, i)
out_weight_list.append(out_gate)
return out, out_lis, out_weight_list
@ -524,10 +522,10 @@ class AdaRNN(nn.Module):
else:
weight = weight_mat
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
for i in range(len(out_list_s)):
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
for i, n in enumerate(out_list_s):
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
for j in range(self.len_seq):
loss_trans = criterion_transder.compute(out_list_s[i][:, j, :], out_list_t[i][:, j, :])
loss_trans = criterion_transder.compute(n[:, j, :], out_list_t[i][:, j, :])
loss_transfer = loss_transfer + weight[i, j] * loss_trans
dist_mat[i, j] = loss_trans
return fc_out, loss_transfer, dist_mat, weight
@ -546,7 +544,7 @@ class AdaRNN(nn.Module):
def predict(self, x):
out = self.gru_features(x, predict=True)
fea = out[0]
if self.use_bottleneck == True:
if self.use_bottleneck is True:
fea_bottleneck = self.bottleneck(fea[:, -1, :])
fc_out = self.fc(fea_bottleneck).squeeze()
else:
@ -572,12 +570,12 @@ class TransferLoss:
Returns:
[tensor] -- transfer loss
"""
if self.loss_type == "mmd_lin" or self.loss_type == "mmd":
if self.loss_type in ("mmd_lin", "mmd"):
mmdloss = MMD_loss(kernel_type="linear")
loss = mmdloss(X, Y)
elif self.loss_type == "coral":
loss = CORAL(X, Y)
elif self.loss_type == "cosine" or self.loss_type == "cos":
elif self.loss_type in ("cosine", "cos"):
loss = 1 - cosine(X, Y)
elif self.loss_type == "kl":
loss = kl_div(X, Y)
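Two rewrites recur throughout this file: index-based loops become enumerate loops (consider-using-enumerate), and chains of equality tests against the same variable become membership tests (consider-using-in). Both are behaviour-preserving; a compact sketch with made-up data:

    features = ["open", "close", "volume"]

    # before: the loop only needs the index to re-read the list
    for i in range(len(features)):
        print(i, features[i])

    # after: enumerate yields index and element together
    for i, name in enumerate(features):
        print(i, name)

    loss_type = "cos"
    # before: repeated equality tests joined with "or"
    if loss_type == "cosine" or loss_type == "cos":
        print("cosine loss")
    # after: a single membership test over a tuple
    if loss_type in ("cosine", "cos"):
        print("cosine loss")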

View file

@ -20,7 +20,6 @@ from qlib.contrib.model.pytorch_lstm import LSTMModel
from qlib.contrib.model.pytorch_utils import count_parameters
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset.processor import CSRankNorm
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import get_or_create_path

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -150,7 +149,7 @@ class ALSTM(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
@ -312,8 +311,8 @@ class ALSTMModel(nn.Module):
def _build_model(self):
try:
klass = getattr(nn, self.rnn_type.upper())
except:
raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
except Exception as e:
raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
self.net = nn.Sequential()
self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
self.net.add_module("act", nn.Tanh())

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -20,7 +19,7 @@ from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.utils import ConcatDataset
from ...data.dataset.weight import Reweighter
@ -160,7 +159,7 @@ class ALSTM(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
@ -320,8 +319,8 @@ class ALSTMModel(nn.Module):
def _build_model(self):
try:
klass = getattr(nn, self.rnn_type.upper())
except:
raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
except Exception as e:
raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
self.net = nn.Sequential()
self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
self.net.add_module("act", nn.Tanh())

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -158,7 +157,7 @@ class GATs(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
@ -263,7 +262,9 @@ class GATs(Model):
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
pretrained_dict = {
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict
} # pylint: disable=E1135
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
@ -19,7 +18,6 @@ from torch.utils.data import Sampler
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...contrib.model.pytorch_lstm import LSTMModel
from ...contrib.model.pytorch_gru import GRUModel
@ -178,7 +176,7 @@ class GATs(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)
@ -279,7 +277,9 @@ class GATs(Model):
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
model_dict = self.GAT_model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
pretrained_dict = {
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict
} # pylint: disable=E1135
model_dict.update(pretrained_dict)
self.GAT_model.load_state_dict(model_dict)
self.logger.info("Loading pretrained model Done...")

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -150,7 +149,7 @@ class GRU(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
@ -19,7 +18,6 @@ from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.utils import ConcatDataset
from ...data.dataset.weight import Reweighter
@ -159,7 +157,7 @@ class GRU(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -17,11 +16,9 @@ from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from torch.nn.modules.container import ModuleList
@ -102,7 +99,7 @@ class LocalformerModel(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
@ -18,9 +17,8 @@ import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from torch.nn.modules.container import ModuleList
@ -101,7 +99,7 @@ class LocalformerModel(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -146,7 +145,7 @@ class LSTM(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
@ -18,7 +17,6 @@ import torch.optim as optim
from torch.utils.data import DataLoader
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from ...model.utils import ConcatDataset
from ...data.dataset.weight import Reweighter
@ -155,7 +153,7 @@ class LSTM(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -328,6 +328,7 @@ class Net(nn.Module):
dnn_layers = []
drop_input = nn.Dropout(0.05)
dnn_layers.append(drop_input)
hidden_units = None
for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
fc = nn.Linear(_input_dim, hidden_units)
activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
@ -338,7 +339,7 @@ class Net(nn.Module):
dnn_layers.append(drop_input)
fc = nn.Linear(hidden_units, output_dim)
dnn_layers.append(fc)
# optimizer
# optimizer # pylint: disable=W0631
self.dnn_layers = nn.ModuleList(dnn_layers)
self._weight_init()
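The hidden_units = None line added above addresses W0631 (undefined-loop-variable): the name is reused after the for loop, which is undefined if the iterable happens to be empty. A stripped-down sketch of the same shape, with hypothetical layer sizes:

    layer_sizes = [16, 32, 64]

    width = None                 # pre-initialise the name used after the loop
    for width in layer_sizes:
        pass                     # build one layer per width here

    if width is None:
        raise ValueError("no layers configured")
    print("output layer takes", width, "inputs")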

View file

@ -4,7 +4,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -435,7 +434,7 @@ class SFM(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -3,7 +3,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -378,7 +377,7 @@ class TabnetModel(Model):
def metric_fn(self, pred, label):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -15,7 +15,6 @@ from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import weight_norm
from .pytorch_utils import count_parameters
from ...model.base import Model
@ -158,7 +157,7 @@ class TCN(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -158,7 +158,7 @@ class TCN(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -5,20 +5,12 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
import random
from sklearn.metrics import roc_auc_score, mean_squared_error
import logging
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...log import get_module_logger, TimeInspector
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
@ -263,7 +255,7 @@ class TCTS(Model):
x_valid, y_valid = df_valid["feature"], df_valid["label"]
x_test, y_test = df_test["feature"], df_test["label"]
if save_path == None:
if save_path is None:
save_path = get_or_create_path(save_path)
best_loss = np.inf
while best_loss > self.lowest_valid_performance:

View file

@ -6,10 +6,8 @@ import os
import copy
import math
import json
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
@ -24,7 +22,6 @@ except ImportError:
from tqdm import tqdm
from qlib.utils import get_or_create_path
from qlib.constant import EPS
from qlib.log import get_module_logger
from qlib.model.base import Model

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
@ -17,11 +16,9 @@ from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
# qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml ”
@ -101,7 +98,7 @@ class TransformerModel(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -5,7 +5,6 @@
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
import copy
@ -18,9 +17,8 @@ import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@ -98,7 +96,7 @@ class TransformerModel(Model):
mask = torch.isfinite(label)
if self.metric == "" or self.metric == "loss":
if self.metric in ("", "loss"):
return -self.loss_fn(pred[mask], label[mask])
raise ValueError("unknown metric `%s`" % self.metric)

View file

@ -26,11 +26,11 @@ def count_parameters(models_or_parameters, unit="m"):
else:
counts = sum(v.numel() for v in models_or_parameters)
unit = unit.lower()
if unit == "kb" or unit == "k":
if unit in ("kb", "k"):
counts /= 2 ** 10
elif unit == "mb" or unit == "m":
elif unit in ("mb", "m"):
counts /= 2 ** 20
elif unit == "gb" or unit == "g":
elif unit in ("gb", "g"):
counts /= 2 ** 30
elif unit is not None:
raise ValueError("Unknown unit: {:}".format(unit))

View file

@ -1,6 +1,5 @@
# MIT License
# Copyright (c) 2018 CMU Locus Lab
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

View file

@ -1,3 +1,5 @@
# pylint: skip-file
'''
TODO:

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
import yaml
import pathlib
import pandas as pd

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
import random
import pandas as pd
from ...data import D

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
import fire
import pandas as pd
import pathlib

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
import logging
from ...log import get_module_logger

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
import pathlib
import pickle
import yaml

View file

@ -1,12 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
import numpy as np
import pandas as pd
from datetime import datetime
import qlib
from qlib.data import D
from qlib.data.cache import H
from qlib.data.data import Cal
from qlib.data.ops import ElemOperator

View file

@ -34,7 +34,7 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
{
"Group%d"
% (i + 1): pred_label_drop.groupby(level="datetime")["label"].apply(
lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean()
lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean() # pylint: disable=W0640
)
for i in range(N)
}
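The disable added above marks W0640 (cell-var-from-loop): a lambda defined in a loop closes over the loop variable itself, not its current value. In the groupby call the lambda is applied immediately within the same iteration, so the warning is presumably a false positive here and is suppressed rather than rewritten. The underlying pitfall, in a runnable sketch:

    # every lambda sees the final value of i
    callbacks = [lambda: i for i in range(3)]
    print([f() for f in callbacks])        # [2, 2, 2]

    # binding i as a default argument captures the value at creation time
    callbacks = [lambda i=i: i for i in range(3)]
    print([f() for f in callbacks])        # [0, 1, 2]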

View file

@ -282,8 +282,10 @@ class SubplotsGraph:
if self._subplots_kwargs is None:
self._init_subplots_kwargs()
self.__cols = self._subplots_kwargs.get("cols", 2)
self.__rows = self._subplots_kwargs.get("rows", math.ceil(len(self._df.columns) / self.__cols))
self.__cols = self._subplots_kwargs.get("cols", 2) # pylint: disable=W0238
self.__rows = self._subplots_kwargs.get( # pylint: disable=W0238
"rows", math.ceil(len(self._df.columns) / self.__cols)
)
self._sub_graph_data = sub_graph_data
if self._sub_graph_data is None:

View file

@ -10,4 +10,3 @@ class BaseOptimizer(abc.ABC):
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> object:
"""Generate a optimized portfolio allocation"""
pass

View file

@ -3,7 +3,6 @@
import numpy as np
import cvxpy as cp
import pandas as pd
from typing import Union, Optional, Dict, Any, List
@ -156,7 +155,7 @@ class EnhancedIndexingOptimizer(BaseOptimizer):
# factor deviation
if self.f_dev is not None:
cons.extend([v >= -self.f_dev, v <= self.f_dev])
cons.extend([v >= -self.f_dev, v <= self.f_dev]) # pylint: disable=E1130
# total turnover constraint
t_cons = []

View file

@ -6,7 +6,6 @@ This order generator is for strategies based on WeightStrategyBase
"""
from ...backtest.position import Position
from ...backtest.exchange import Exchange
from ...backtest.decision import BaseTradeDecision, TradeDecisionWO
import pandas as pd
import copy

View file

@ -3,7 +3,6 @@
import os
import copy
import warnings
import cvxpy as cp
import numpy as np
import pandas as pd
@ -15,11 +14,10 @@ from qlib.model.base import BaseModel
from qlib.strategy.base import BaseStrategy
from qlib.backtest.position import Position
from qlib.backtest.signal import Signal, create_signal_from
from qlib.backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
from qlib.log import get_module_logger
from qlib.utils import get_pre_trading_date, load_dataset
from qlib.utils.resam import resam_ts_data
from qlib.contrib.strategy.order_generator import OrderGenWInteract, OrderGenWOInteract
from qlib.contrib.strategy.order_generator import OrderGenWOInteract
from qlib.contrib.strategy.optimizer import EnhancedIndexingOptimizer

View file

@ -0,0 +1 @@
# pylint: skip-file

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
import yaml
import copy
import os

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
# coding=utf-8
import argparse

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
import os
import json
import logging

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
from hyperopt import hp

View file

@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: skip-file
import os
import yaml
import json

View file

@ -6,7 +6,6 @@ from __future__ import division
from __future__ import print_function
import abc
import pandas as pd
from ..log import get_module_logger
@ -21,107 +20,107 @@ class Expression(abc.ABC):
return str(self)
def __gt__(self, other):
from .ops import Gt
from .ops import Gt # pylint: disable=C0415
return Gt(self, other)
def __ge__(self, other):
from .ops import Ge
from .ops import Ge # pylint: disable=C0415
return Ge(self, other)
def __lt__(self, other):
from .ops import Lt
from .ops import Lt # pylint: disable=C0415
return Lt(self, other)
def __le__(self, other):
from .ops import Le
from .ops import Le # pylint: disable=C0415
return Le(self, other)
def __eq__(self, other):
from .ops import Eq
from .ops import Eq # pylint: disable=C0415
return Eq(self, other)
def __ne__(self, other):
from .ops import Ne
from .ops import Ne # pylint: disable=C0415
return Ne(self, other)
def __add__(self, other):
from .ops import Add
from .ops import Add # pylint: disable=C0415
return Add(self, other)
def __radd__(self, other):
from .ops import Add
from .ops import Add # pylint: disable=C0415
return Add(other, self)
def __sub__(self, other):
from .ops import Sub
from .ops import Sub # pylint: disable=C0415
return Sub(self, other)
def __rsub__(self, other):
from .ops import Sub
from .ops import Sub # pylint: disable=C0415
return Sub(other, self)
def __mul__(self, other):
from .ops import Mul
from .ops import Mul # pylint: disable=C0415
return Mul(self, other)
def __rmul__(self, other):
from .ops import Mul
from .ops import Mul # pylint: disable=C0415
return Mul(self, other)
def __div__(self, other):
from .ops import Div
from .ops import Div # pylint: disable=C0415
return Div(self, other)
def __rdiv__(self, other):
from .ops import Div
from .ops import Div # pylint: disable=C0415
return Div(other, self)
def __truediv__(self, other):
from .ops import Div
from .ops import Div # pylint: disable=C0415
return Div(self, other)
def __rtruediv__(self, other):
from .ops import Div
from .ops import Div # pylint: disable=C0415
return Div(other, self)
def __pow__(self, other):
from .ops import Power
from .ops import Power # pylint: disable=C0415
return Power(self, other)
def __and__(self, other):
from .ops import And
from .ops import And # pylint: disable=C0415
return And(self, other)
def __rand__(self, other):
from .ops import And
from .ops import And # pylint: disable=C0415
return And(other, self)
def __or__(self, other):
from .ops import Or
from .ops import Or # pylint: disable=C0415
return Or(self, other)
def __ror__(self, other):
from .ops import Or
from .ops import Or # pylint: disable=C0415
return Or(other, self)
@ -144,7 +143,7 @@ class Expression(abc.ABC):
pd.Series
feature series: The index of the series is the calendar index
"""
from .cache import H
from .cache import H # pylint: disable=C0415
# cache
args = str(self), instrument, start_index, end_index, freq
@ -215,7 +214,7 @@ class Feature(Expression):
def _load_internal(self, instrument, start_index, end_index, freq):
# load
from .data import FeatureD
from .data import FeatureD # pylint: disable=C0415
return FeatureD.feature(instrument, str(self), start_index, end_index, freq)
@ -232,5 +231,3 @@ class ExpressionOps(Expression):
This kind of feature will use operator for feature
construction on the fly.
"""
pass
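The "# pylint: disable=C0415" comments sprinkled through this file, and through several other modules in this commit, all cover the same situation: qlib defers these imports into the method body on purpose, to break circular dependencies between the data, ops and cache modules, so the commit suppresses import-outside-toplevel at each site rather than hoisting the imports. The same lazy-import shape, shown with a hypothetical heavy optional dependency:

    def plot_report(df):
        # imported only when plotting is actually requested, so importing this
        # module stays cheap and the optional dependency stays optional
        import matplotlib.pyplot as plt  # pylint: disable=import-outside-toplevel

        df.plot()
        plt.show()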

View file

@ -33,8 +33,7 @@ from ..utils import (
from ..log import get_module_logger
from .base import Feature
from .ops import Operators
from .ops import Operators # pylint: disable=W0611
class QlibCacheException(RuntimeError):
@ -229,8 +228,8 @@ class CacheUtils:
try:
d["meta"]["last_visit"] = str(time.time())
d["meta"]["visits"] = d["meta"]["visits"] + 1
except KeyError:
raise KeyError("Unknown meta keyword")
except KeyError as key_e:
raise KeyError("Unknown meta keyword") from key_e
pickle.dump(d, f, protocol=C.dump_protocol_version)
except Exception as e:
get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}")
@ -239,7 +238,7 @@ class CacheUtils:
def acquire(lock, lock_name):
try:
lock.acquire()
except redis_lock.AlreadyAcquired:
except redis_lock.AlreadyAcquired as lock_acquired:
raise QlibCacheException(
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
You can use the following command to clear your redis keys and rerun your commands:
@ -249,7 +248,7 @@ class CacheUtils:
> quit
If the issue is not resolved, use "keys *" to find if multiple keys exist. If so, try using "flushall" to clear all the keys.
"""
)
) from lock_acquired
@staticmethod
@contextlib.contextmanager
@ -507,7 +506,7 @@ class DiskExpressionCache(ExpressionCache):
_instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())
cache_path = _instrument_dir.joinpath(_cache_uri)
# get calendar
from .data import Cal
from .data import Cal # pylint: disable=C0415
_calendar = Cal.calendar(freq=freq)
@ -599,7 +598,7 @@ class DiskExpressionCache(ExpressionCache):
last_update_time = d["info"]["last_update"]
# get newest calendar
from .data import Cal, ExpressionD
from .data import Cal, ExpressionD # pylint: disable=C0415
whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq)
# calendar since last updated.
@ -753,7 +752,7 @@ class DiskDatasetCache(DatasetCache):
if disk_cache == 0:
# In this case, server only checks the expression cache.
# The client will load the cache data by itself.
from .data import LocalDatasetProvider
from .data import LocalDatasetProvider # pylint: disable=C0415
LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)
return ""
@ -895,7 +894,7 @@ class DiskDatasetCache(DatasetCache):
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
"""
# get calendar
from .data import Cal
from .data import Cal # pylint: disable=C0415
cache_path = Path(cache_path)
_calendar = Cal.calendar(freq=freq)
@ -970,14 +969,14 @@ class DiskDatasetCache(DatasetCache):
index_data = im.get_index()
self.logger.debug("Updating dataset: {}".format(d))
from .data import Inst
from .data import Inst # pylint: disable=C0415
if Inst.get_inst_type(instruments) == Inst.DICT:
self.logger.info(f"The file {cache_uri} has dict cache. Skip updating")
return 1
# get newest calendar
from .data import Cal
from .data import Cal # pylint: disable=C0415
whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq)
# The calendar since last updated
@ -994,7 +993,7 @@ class DiskDatasetCache(DatasetCache):
current_index = len(whole_calendar) - len(new_calendar) + 1
# To avoid recursive import
from .data import ExpressionD
from .data import ExpressionD # pylint: disable=C0415
# The existing data length
lft_etd = rght_etd = 0

View file

@ -5,17 +5,13 @@
from __future__ import division
from __future__ import print_function
import os
import re
import abc
import time
import copy
import queue
import bisect
import numpy as np
import pandas as pd
from multiprocessing import Pool
from typing import Iterable, Union
from typing import List, Union
# For supporting multiprocessing in outer code, joblib is used
@ -23,13 +19,10 @@ from joblib import delayed
from .cache import H
from ..config import C
from .base import Feature
from .ops import Operators
from .inst_processor import InstProcessor
from ..log import get_module_logger
from ..utils.time import Freq
from .cache import DiskDatasetCache, DiskExpressionCache
from .cache import DiskDatasetCache
from ..utils import (
Wrapper,
init_instance_by_config,
@ -43,6 +36,7 @@ from ..utils import (
time_to_slc_point,
)
from ..utils.paral import ParallelExt
from .ops import Operators # pylint: disable=W0611
class ProviderBackendMixin:
@ -144,10 +138,10 @@ class CalendarProvider(abc.ABC):
if start_time not in calendar_index:
try:
start_time = calendar[bisect.bisect_left(calendar, start_time)]
except IndexError:
except IndexError as index_e:
raise IndexError(
"`start_time` uses a future date, if you want to get future trading days, you can use: `future=True`"
)
) from index_e
start_index = calendar_index[start_time]
if end_time not in calendar_index:
end_time = calendar[bisect.bisect_right(calendar, end_time) - 1]
@ -246,7 +240,7 @@ class InstrumentProvider(abc.ABC):
"""
if isinstance(market, list):
return market
from .filter import SeriesDFilter
from .filter import SeriesDFilter # pylint: disable=C0415
if filter_pipe is None:
filter_pipe = []
@ -672,7 +666,7 @@ class LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin):
# filter
filter_pipe = instruments["filter_pipe"]
for filter_config in filter_pipe:
from . import filter as F
from . import filter as F # pylint: disable=C0415
filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config)
_instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq)
@ -1003,8 +997,8 @@ class ClientDatasetProvider(DatasetProvider):
if return_uri:
return df, feature_uri
return df
except AttributeError:
raise IOError("Unable to fetch instruments from remote server!")
except AttributeError as attribute_e:
raise IOError("Unable to fetch instruments from remote server!") from attribute_e
class BaseProvider:
@ -1110,7 +1104,7 @@ class ClientProvider(BaseProvider):
return isinstance(instance, cls)
from .client import Client
from .client import Client # pylint: disable=C0415
self.client = Client(C.flask_server, C.flask_port)
self.logger = get_module_logger(self.__class__.__name__)

View file

@ -52,7 +52,6 @@ class Dataset(Serializable):
- User prepare data for model based on previous status.
"""
pass
def prepare(self, **kwargs) -> object:
"""
@ -68,7 +67,6 @@ class Dataset(Serializable):
object:
return the object
"""
pass
class DatasetH(Dataset):
@ -348,7 +346,7 @@ class TSDataSampler:
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.data_index = self.data_index[np.where(self.flt_data is True)[0]]
self.idx_map = self.idx_map2arr(self.idx_map)
self.start_idx, self.end_idx = self.data_index.slice_locs(

View file

@ -2,24 +2,16 @@
# Licensed under the MIT License.
# coding=utf-8
import abc
import bisect
import logging
import warnings
from inspect import getfullargspec
from typing import Callable, Union, Tuple, List, Iterator, Optional
import pandas as pd
import numpy as np
from ...log import get_module_logger, TimeInspector
from ...data import D
from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils import init_instance_by_config
from ...utils.serial import Serializable
from .utils import fetch_df_by_index, fetch_df_by_col
from ...utils import lazy_sort_index
from pathlib import Path
from .loader import DataLoader
from . import processor as processor_module
@ -228,7 +220,7 @@ class DataHandler(Serializable):
proc_func: Callable = None,
):
# This method is extracted for sharing in subclasses
from .storage import BaseHandlerStorage
from .storage import BaseHandlerStorage # pylint: disable=C0415
# Following conflictions may occurs
# - Does [20200101", "20210101"] mean selecting this slice or these two days?
@ -627,7 +619,6 @@ class DataHandlerLP(DataHandler):
-------
pd.DataFrame:
"""
from .storage import BaseHandlerStorage
return self._fetch_data(
data_storage=self._get_df_by_key(data_key),

View file

@ -51,7 +51,6 @@ class DataLoader(abc.ABC):
pd.DataFrame:
data load from the under layer source
"""
pass
class DLWParser(DataLoader):
@ -129,7 +128,6 @@ class DLWParser(DataLoader):
pd.DataFrame:
the queried dataframe.
"""
pass
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if self.is_group:
@ -308,7 +306,7 @@ class DataLoaderDH(DataLoader):
is_group will be used to describe whether the key of handler_config is group
"""
from qlib.data.dataset.handler import DataHandler
from qlib.data.dataset.handler import DataHandler # pylint: disable=C0415
if is_group:
self.handlers = {

View file

@ -42,7 +42,6 @@ class Processor(Serializable):
processor, i.e. `df`.
"""
pass
@abc.abstractmethod
def __call__(self, df: pd.DataFrame):
@ -57,7 +56,6 @@ class Processor(Serializable):
df : pd.DataFrame
The raw_df of handler or result from previous processor.
"""
pass
def is_for_infer(self) -> bool:
"""
@ -201,7 +199,7 @@ class MinMaxNorm(Processor):
self.fit_end_time = fit_end_time
self.fields_group = fields_group
def fit(self, df):
def fit(self, df: pd.DataFrame = None):
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
cols = get_group_columns(df, self.fields_group)
self.min_val = np.nanmin(df[cols].values, axis=0)
@ -232,7 +230,7 @@ class ZScoreNorm(Processor):
self.fit_end_time = fit_end_time
self.fields_group = fields_group
def fit(self, df):
def fit(self, df: pd.DataFrame = None):
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
cols = get_group_columns(df, self.fields_group)
self.mean_train = np.nanmean(df[cols].values, axis=0)
@ -272,7 +270,7 @@ class RobustZScoreNorm(Processor):
self.fields_group = fields_group
self.clip_outlier = clip_outlier
def fit(self, df):
def fit(self, df: pd.DataFrame = None):
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
self.cols = get_group_columns(df, self.fields_group)
X = df[self.cols].values
@ -351,6 +349,6 @@ class HashStockFormat(Processor):
"""Process the storage of from df into hasing stock format"""
def __call__(self, df: pd.DataFrame):
from .storage import HasingStockStorage
from .storage import HasingStockStorage # pylint: disable=C0415
return HasingStockStorage.from_df(df)

View file

@ -2,7 +2,7 @@ import pandas as pd
import numpy as np
from .handler import DataHandler
from typing import Tuple, Union, List, Callable
from typing import Union, List, Callable
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
@ -109,7 +109,7 @@ class HasingStockStorage(BaseHandlerStorage):
stock_selector = selector[self.stock_level]
elif isinstance(selector, (list, str)) and self.stock_level == 0:
stock_selector = selector
elif level == "instrument" or level == self.stock_level:
elif level in ("instrument", self.stock_level):
if isinstance(selector, tuple):
stock_selector = selector[0]
elif isinstance(selector, (list, str)):

View file

@ -63,7 +63,7 @@ def fetch_df_by_index(
Data of the given index.
"""
# level = None -> use selector directly
if level == None:
if level is None:
return df.loc(axis=0)[selector]
# Try to get the right index
idx_slc = (selector, slice(None, None))
@ -75,7 +75,7 @@ def fetch_df_by_index(
return df.loc[
pd.IndexSlice[idx_slc],
]
else:
else: # pylint: disable=W0120
return df
else:
return df.loc[
@ -84,7 +84,7 @@ def fetch_df_by_index(
def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
from .handler import DataHandler
from .handler import DataHandler # pylint: disable=C0415
if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW:
return df
@ -136,7 +136,7 @@ def init_task_handler(task: dict) -> Union[DataHandler, None]:
returns
"""
# avoid recursive import
from .handler import DataHandler
from .handler import DataHandler # pylint: disable=C0415
h_conf = task["dataset"]["kwargs"].get("handler")
if h_conf is not None:

View file

@ -1,13 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import numpy as np
from typing import Union, List, Tuple
from ...data.dataset import TSDataSampler
from ...data.dataset.utils import get_level_index
from ...utils import lazy_sort_index
class Reweighter:
def __init__(self, *args, **kwargs):

View file

@ -62,7 +62,7 @@ class SeriesDFilter(BaseDFilter):
Override _getFilterSeries to use the rule to filter the series and get a dict of {inst => series}, or override filter_main for more advanced series filter rule
"""
def __init__(self, fstart_time=None, fend_time=None):
def __init__(self, fstart_time=None, fend_time=None, keep=False):
"""Init function for filter base class.
Filter a set of instruments based on a certain rule within a certain period assigned by fstart_time and fend_time.
@ -72,10 +72,13 @@ class SeriesDFilter(BaseDFilter):
the time for the filter rule to start filter the instruments.
fend_time: str
the time for the filter rule to stop filter the instruments.
keep: bool
whether to keep the instruments of which features don't exist in the filter time span.
"""
super(SeriesDFilter, self).__init__()
self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None
self.filter_end_time = pd.Timestamp(fend_time) if fend_time else None
self.keep = keep
def _getTimeBound(self, instruments):
"""Get time bound for all instruments.
@ -330,12 +333,9 @@ class ExpressionDFilter(SeriesDFilter):
filter the feature ending by this time.
rule_expression: str
an input expression for the rule.
keep: bool
whether to keep the instruments of which features don't exist in the filter time span.
"""
super(ExpressionDFilter, self).__init__(fstart_time, fend_time)
super(ExpressionDFilter, self).__init__(fstart_time, fend_time, keep=keep)
self.rule_expression = rule_expression
self.keep = keep
def _getFilterSeries(self, instruments, fstart, fend):
# do not use dataset cache

View file

@ -17,7 +17,6 @@ class InstProcessor:
df : pd.DataFrame
The raw_df of handler or result from previous processor.
"""
pass
def __str__(self):
return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"

View file

@ -5,8 +5,6 @@
from __future__ import division
from __future__ import print_function
import sys
import abc
import numpy as np
import pandas as pd
@ -15,7 +13,6 @@ from scipy.stats import percentileofscore
from .base import Expression, ExpressionOps, Feature
from ..config import C
from ..log import get_module_logger
from ..utils import get_callable_kwargs
@ -331,7 +328,7 @@ class NpPairOperator(PairOperator):
res = getattr(np, self.func)(series_left, series_right)
except ValueError as e:
get_module_logger("ops").debug(warning_info)
raise ValueError(f"{str(e)}. \n\t{warning_info}")
raise ValueError(f"{str(e)}. \n\t{warning_info}") from e
else:
if check_length and len(series_left) != len(series_right):
get_module_logger("ops").debug(warning_info)
@ -1430,21 +1427,20 @@ class PairRolling(ExpressionOps):
return max(left_br, right_br)
def get_extended_window_size(self):
if isinstance(self.feature_left, Expression):
ll, lr = self.feature_left.get_extended_window_size()
else:
ll, lr = 0, 0
if isinstance(self.feature_right, Expression):
rl, rr = self.feature_right.get_extended_window_size()
else:
rl, rr = 0, 0
if self.N == 0:
get_module_logger(self.__class__.__name__).warning(
"The PairRolling(ATTR, 0) will not be accurately calculated"
)
return -np.inf, max(lr, rr)
else:
if isinstance(self.feature_left, Expression):
ll, lr = self.feature_left.get_extended_window_size()
else:
ll, lr = 0, 0
if isinstance(self.feature_right, Expression):
rl, rr = self.feature_right.get_extended_window_size()
else:
rl, rr = 0, 0
return max(ll, rl) + self.N - 1, max(lr, rr)

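The restructured get_extended_window_size hoists the child-window computation out of the N == 0 branch without changing the arithmetic. A standalone sketch of that arithmetic (not qlib code), using a 10-bar pair rolling operator as an example:

    import numpy as np

    def pair_rolling_extended_window(N, left=(0, 0), right=(0, 0)):
        # left/right are the extra (past, future) bars each child expression needs.
        ll, lr = left
        rl, rr = right
        if N == 0:
            # A zero window means "expanding", so the look-back is unbounded.
            return -np.inf, max(lr, rr)
        return max(ll, rl) + N - 1, max(lr, rr)

    # Plain features need no extra history, so a 10-bar pair rolling operator
    # needs 9 extra past bars and no future bars.
    print(pair_rolling_extended_window(10))   # (9, 0)
    print(pair_rolling_extended_window(0))    # (-inf, 0)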
View File

@ -13,7 +13,7 @@ from .config import C
class MetaLogger(type):
def __new__(mcs, name, bases, attrs):
def __new__(mcs, name, bases, attrs): # pylint: disable=C0204
wrapper_dict = logging.Logger.__dict__.copy()
for key in wrapper_dict:
if key not in attrs and key != "__reduce__":
@ -164,7 +164,7 @@ class LogFilter(logging.Filter):
if isinstance(self.param, str):
allow = not self.match_msg(self.param, record.msg)
elif isinstance(self.param, list):
allow = not any([self.match_msg(p, record.msg) for p in self.param])
allow = not any(self.match_msg(p, record.msg) for p in self.param)
return allow
@ -201,7 +201,7 @@ def set_global_logger_level(level: int, return_orig_handler_level: bool = False)
"""
_handler_level_map = {}
qlib_logger = logging.root.manager.loggerDict.get("qlib", None)
qlib_logger = logging.root.manager.loggerDict.get("qlib", None) # pylint: disable=E1101
if qlib_logger is not None:
for _handler in qlib_logger.handlers:
_handler_level_map[_handler] = _handler.level

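The LogFilter change above only swaps a list comprehension for a generator expression inside any(); the result is identical, but no intermediate list is built. A generic illustration, using re.match as a stand-in for match_msg (the pattern list and message are invented for the example):

    import re

    param = ["^ignore", "^also-ignore"]
    record_msg = "this message should pass"
    allow = not any(re.match(p, record_msg) for p in param)
    print(allow)  # True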
View File

@ -13,7 +13,6 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
@abc.abstractmethod
def predict(self, *args, **kwargs) -> object:
"""Make predictions after modeling things"""
pass
def __call__(self, *args, **kwargs) -> object:
"""leverage Python syntactic sugar to make the models' behaviors like functions"""

View File

@ -13,7 +13,7 @@ reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
"""
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
from typing import Callable, Union
from typing import Callable
from joblib import Parallel, delayed

View File

@ -27,6 +27,9 @@ class FeatureInt:
class LightGBMFInt(FeatureInt):
"""LightGBM (F)eature (Int)erpreter"""
def __init__(self):
self.model = None
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
@ -35,6 +38,8 @@ class LightGBMFInt(FeatureInt):
parameters reference:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
"""
return pd.Series(self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()).sort_values(
return pd.Series(
self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()
).sort_values( # pylint: disable=E1101
ascending=False
)

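LightGBMFInt now initialises self.model = None so pylint can resolve the attribute, and the importance Series is built across several lines. A self-contained sketch of the same pattern against a throwaway LightGBM booster; the data and training parameters are illustrative only:

    import lightgbm as lgb
    import numpy as np
    import pandas as pd

    X = pd.DataFrame(np.random.rand(200, 3), columns=["f0", "f1", "f2"])
    y = np.random.rand(200)
    booster = lgb.train(
        {"objective": "regression", "verbosity": -1},
        lgb.Dataset(X, y),
        num_boost_round=20,
    )

    importance = pd.Series(
        booster.feature_importance(importance_type="gain"),
        index=booster.feature_name(),
    ).sort_values(ascending=False)
    print(importance)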
View File

@ -4,8 +4,6 @@
import abc
from qlib.model.meta.task import MetaTask
from typing import Dict, Union, List, Tuple, Text
from ...workflow.task.gen import RollingGen, task_generator
from ...data.dataset.handler import DataHandler
from ...utils.serial import Serializable
@ -73,4 +71,3 @@ class MetaTaskDataset(Serializable, metaclass=abc.ABCMeta):
seg : Text
the name of the segment
"""
pass

View File

@ -2,10 +2,8 @@
# Licensed under the MIT License.
import abc
from qlib.contrib.meta.data_selection.dataset import MetaDatasetDS
from typing import Union, List, Tuple
from typing import List
from qlib.model.meta.task import MetaTask
from .dataset import MetaTaskDataset
@ -23,7 +21,6 @@ class MetaModel(metaclass=abc.ABCMeta):
"""
The training process of the meta-model.
"""
pass
@abc.abstractmethod
def inference(self, *args, **kwargs) -> object:
@ -35,7 +32,6 @@ class MetaModel(metaclass=abc.ABCMeta):
object:
Some information to guide the model learning
"""
pass
class MetaTaskModel(MetaModel):

View File

@ -1,9 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
from typing import Union, List, Tuple
from qlib.data.dataset import Dataset
from ...utils import init_instance_by_config

View File

@ -91,7 +91,7 @@ class RiskModel(BaseModel):
"return_decomposed_components" in inspect.getfullargspec(self._predict).args
), "This risk model does not support return decomposed components of the covariance matrix "
F, cov_b, var_u = self._predict(X, return_decomposed_components=True)
F, cov_b, var_u = self._predict(X, return_decomposed_components=True) # pylint: disable=E1123
return F, cov_b, var_u
# estimate covariance

View File

@ -12,17 +12,13 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
"""
import socket
import time
import re
from typing import Callable, List
from tqdm.auto import tqdm
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config, auto_filter_kwargs, fill_placeholder
from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.data.dataset.weight import Reweighter

View File

@ -7,7 +7,6 @@ from typing import Union
from ..backtest.executor import BaseExecutor
from .interpreter import StateInterpreter, ActionInterpreter
from ..utils import init_instance_by_config
from .interpreter import BaseInterpreter
class BaseRLEnv:

View File

@ -6,12 +6,8 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from typing import List, Tuple, Union
import pandas as pd
from typing import Tuple, Union
from ..model.base import BaseModel
from ..data.dataset import DatasetH
from ..data.dataset.utils import convert_index_format
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager

View File

@ -139,8 +139,8 @@ def parse_config(config):
# Check whether the str can be parsed
try:
return yaml.safe_load(config)
except BaseException:
raise ValueError("cannot parse config!")
except BaseException as base_exp:
raise ValueError("cannot parse config!") from base_exp
#################### Other ####################
@ -436,7 +436,7 @@ def is_tradable_date(cur_date):
date : pandas.Timestamp
current date
"""
from ..data import D
from ..data import D # pylint: disable=C0415
return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())
@ -453,7 +453,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
"""
from ..data import D
from ..data import D # pylint: disable=C0415
start = get_date_by_shift(trading_date, left_shift, future=future)
end = get_date_by_shift(trading_date, right_shift, future=future)
@ -476,7 +476,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="
when align is "left"/"right", it will try to align to left/right nearest trading date before shifting when `trading_date` is not a trading date
"""
from qlib.data import D
from qlib.data import D # pylint: disable=C0415
cal = D.calendar(future=future, freq=freq)
trading_date = pd.to_datetime(trading_date)
@ -529,7 +529,7 @@ def transform_end_date(end_date=None, freq="day"):
date : pandas.Timestamp
current date
"""
from ..data import D
from ..data import D # pylint: disable=C0415
last_date = D.calendar(freq=freq)[-1]
if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)):
@ -810,7 +810,7 @@ def fill_placeholder(config: dict, config_extend: dict):
elif isinstance(now_item, dict):
item_keys = now_item.keys()
for key in item_keys:
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
if isinstance(now_item[key], (list, dict)):
item_queue.append(now_item[key])
tail += 1
elif isinstance(now_item[key], str):

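Two recurring idioms in this file are exception chaining (raise ... from err, so the original traceback is kept as the cause) and a single isinstance call with a tuple of types. A standalone illustration of both; this is not the qlib implementation, and the isinstance branch is added here only to show the tuple form:

    import yaml

    def parse_config(config):
        if isinstance(config, (dict, list)):
            # Already parsed; nothing to do.
            return config
        try:
            return yaml.safe_load(config)
        except BaseException as base_exp:
            # "from" chains the original exception instead of discarding it.
            raise ValueError("cannot parse config!") from base_exp

    print(parse_config("{a: 1}"))   # {'a': 1}

The pylint: disable=C0415 markers elsewhere in the file mark imports that are deliberately deferred into function bodies (for example the local "from ..data import D") to avoid recursive imports, not genuine style problems.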
View File

@ -10,16 +10,10 @@ class QlibException(Exception):
class RecorderInitializationError(QlibException):
"""Error type for re-initialization when starting an experiment"""
pass
class LoadObjectError(QlibException):
"""Error type for Recorder when can not load object"""
pass
class ExpAlreadyExistError(Exception):
"""Experiment already exists"""
pass

View File

@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import contextlib
import os
import shutil
import tempfile

View File

@ -153,8 +153,8 @@ class Index:
"""
try:
return self.index_map[self._convert_type(item)]
except IndexError:
raise KeyError(f"{item} can't be found in {self}")
except IndexError as index_e:
raise KeyError(f"{item} can't be found in {self}") from index_e
def __or__(self, other: "Index"):
return Index(idx_list=list(set(self.idx_list) | set(other.idx_list)))

View File

@ -101,8 +101,10 @@ class FileManager(ObjManager):
def create_path(self) -> str:
try:
return tempfile.mkdtemp(prefix=str(C["file_manager_path"]) + os.sep)
except AttributeError:
raise NotImplementedError(f"If path is not given, the `create_path` function should be implemented")
except AttributeError as attribute_e:
raise NotImplementedError(
f"If path is not given, the `create_path` function should be implemented"
) from attribute_e
def save_obj(self, obj, name):
with (self.path / name).open("wb") as f:

View File

@ -70,12 +70,12 @@ def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=No
the feature with higher or equal frequency
"""
from ..data.data import D
from ..data.data import D # pylint: disable=C0415
try:
_result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache)
_freq = freq
except (ValueError, KeyError):
except (ValueError, KeyError) as value_key_e:
_, norm_freq = Freq.parse(freq)
if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]:
try:
@ -88,7 +88,7 @@ def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=No
_result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
_freq = "1min"
else:
raise ValueError(f"freq {freq} is not supported")
raise ValueError(f"freq {freq} is not supported") from value_key_e
return _result, _freq
@ -172,7 +172,7 @@ def resam_ts_data(
selector_datetime = slice(start_time, end_time)
from ..data.dataset.utils import get_level_index
from ..data.dataset.utils import get_level_index # pylint: disable=C0415
feature = lazy_sort_index(ts_feature)

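resam_ts_data builds selector_datetime = slice(start_time, end_time) and sorts the data first via lazy_sort_index; pandas needs a monotonic index for reliable label slicing. A minimal reminder of that requirement (not qlib code; the sample series is invented):

    import pandas as pd

    s = pd.Series(
        range(4),
        index=pd.to_datetime(["2020-01-03", "2020-01-01", "2020-01-04", "2020-01-02"]),
    )
    s = s.sort_index()  # roughly what lazy_sort_index guarantees before slicing
    print(s.loc[slice("2020-01-02", "2020-01-03")])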
View File

@ -2,7 +2,7 @@
# Licensed under the MIT License.
from contextlib import contextmanager
from typing import Text, Optional, Any, Dict, Text, Optional
from typing import Text, Optional, Any, Dict
from .expm import ExpManager
from .exp import Experiment
from .recorder import Recorder

View File

@ -1,7 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys, os
import sys
import os
from pathlib import Path
import qlib

View File

@ -2,10 +2,10 @@
# Licensed under the MIT License.
from typing import Dict, List, Union
import mlflow, logging
import mlflow
import logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
from pathlib import Path
from .recorder import Recorder, MLflowRecorder
from ..log import get_module_logger
@ -271,7 +271,7 @@ class MLflowExperiment(Experiment):
return self.active_recorder
def end(self, recorder_status):
def end(self, recorder_status=Recorder.STATUS_S):
if self.active_recorder is not None:
self.active_recorder.end_run(recorder_status)
self.active_recorder = None
@ -299,8 +299,10 @@ class MLflowExperiment(Experiment):
run = self._client.get_run(recorder_id)
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
return recorder
except MlflowException:
raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.")
except MlflowException as mlflow_exp:
raise ValueError(
"No valid recorder has been found, please make sure the input recorder id is correct."
) from mlflow_exp
elif recorder_name is not None:
logger.warning(
f"Please make sure the recorder name {recorder_name} is unique, we will only return the latest recorder if there exist several matched the given name."
@ -332,7 +334,7 @@ class MLflowExperiment(Experiment):
except MlflowException as e:
raise Exception(
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
)
) from e
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
@ -362,10 +364,10 @@ class MLflowExperiment(Experiment):
)
rids = []
recorders = []
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
for i, n in enumerate(runs):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=n)
if status is None or recorder.status == status:
rids.append(runs[i].info.run_id)
rids.append(n.info.run_id)
recorders.append(recorder)
if rtype == Experiment.RT_D:

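list_recorders now iterates with enumerate instead of indexing runs[i]; a trivial generic form of the pattern:

    runs = ["run_a", "run_b", "run_c"]
    for i, run in enumerate(runs):
        print(i, run)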
View File

@ -6,9 +6,7 @@ import mlflow
from filelock import FileLock
from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.entities import ViewType
import os, logging
from pathlib import Path
from contextlib import contextmanager
import os
from typing import Optional, Text
from .exp import MLflowExperiment, Experiment
@ -203,7 +201,7 @@ class ExpManager:
# So we supported it in the interface wrapper
pr = urlparse(self.uri)
if pr.scheme == "file":
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f:
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f: # pylint: disable=E0110
return self.create_exp(experiment_name), True
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
try:
@ -363,7 +361,7 @@ class MLflowExpManager(ExpManager):
experiment_id = self.client.create_experiment(experiment_name)
except MlflowException as e:
if e.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
raise ExpAlreadyExistError()
raise ExpAlreadyExistError() from e
raise e
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
@ -387,10 +385,10 @@ class MLflowExpManager(ExpManager):
raise MlflowException("No valid experiment has been found.")
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
return experiment
except MlflowException:
except MlflowException as e:
raise ValueError(
"No valid experiment has been found, please make sure the input experiment id is correct."
)
) from e
elif experiment_name is not None:
try:
exp = self.client.get_experiment_by_name(experiment_name)
@ -401,9 +399,9 @@ class MLflowExpManager(ExpManager):
except MlflowException as e:
raise ValueError(
"No valid experiment has been found, please make sure the input experiment name is correct."
)
) from e
def search_records(self, experiment_ids, **kwargs):
def search_records(self, experiment_ids=None, **kwargs):
filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string")
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
@ -425,7 +423,7 @@ class MLflowExpManager(ExpManager):
except MlflowException as e:
raise Exception(
f"Error: {e}. Something went wrong when deleting experiment. Please check if the name/id of the experiment is correct."
)
) from e
def list_experiments(self):
# retrieve all the existing experiments

View File

@ -83,15 +83,14 @@ For simplicity
"""
import logging
from typing import Callable, Dict, List, Union
from typing import Callable, List, Union
import pandas as pd
from qlib import get_module_logger
from qlib.data.data import D
from qlib.log import set_global_logger_level
from qlib.model.ens.ensemble import AverageEnsemble
from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR
from qlib.utils import flatten_dict
from qlib.model.trainer import Trainer, TrainerR
from qlib.utils.serial import Serializable
from qlib.workflow.online.strategy import OnlineStrategy
from qlib.workflow.task.collect import MergeCollector

View File

@ -5,9 +5,7 @@
OnlineStrategy module is an element of online serving.
"""
from copy import deepcopy
from typing import List, Tuple, Union
from qlib.data.data import D
from typing import List, Union
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.utils import transform_end_date

View File

@ -148,7 +148,7 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
self.rmdl = loader_cls(rec=record)
latest_date = D.calendar(freq=freq)[-1]
if to_date == None:
if to_date is None:
to_date = latest_date
to_date = pd.Timestamp(to_date)
@ -191,7 +191,9 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
else:
hist_ref = self.hist_ref
start_time_buffer = get_date_by_shift(self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq)
start_time_buffer = get_date_by_shift(
self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq # pylint: disable=E1130
)
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
seg = {"test": (start_time, self.to_date)}
return self.rmdl.get_dataset(

View File

@ -8,10 +8,8 @@ This allows us to use efficient submodels as the market-style changing.
"""
from typing import List, Union
from qlib.data.dataset import TSDatasetH
from qlib.log import get_module_logger
from qlib.utils import get_callable_kwargs
from qlib.utils.exceptions import LoadObjectError
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder

Some files were not shown because too many files changed in this diff.