Mirror of https://github.com/microsoft/qlib.git
Fix pylint (#888)
* add_pylint_to_workflow
* fix-pylint
* fix_pylinterror
* fix-issue
Parent: 635632e4ed
Commit: 144e1e2459
@@ -33,7 +33,37 @@ jobs:
       - name: Install Qlib with pip
         run: |
           pip install numpy==1.19.5 ruamel.yaml
           pip install pyqlib --ignore-installed
+
+      # Check Qlib with pylint
+      # TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
+      # C0103: invalid-name
+      # C0209: consider-using-f-string
+      # R0402: consider-using-from-import
+      # R1705: no-else-return
+      # R1710: inconsistent-return-statements
+      # R1725: super-with-arguments
+      # R1735: use-dict-literal
+      # W0102: dangerous-default-value
+      # W0212: protected-access
+      # W0221: arguments-differ
+      # W0223: abstract-method
+      # W0231: super-init-not-called
+      # W0237: arguments-renamed
+      # W0612: unused-variable
+      # W0621: redefined-outer-name
+      # W0622: redefined-builtin
+      # FIXME: specify exception type
+      # W0703: broad-except
+      # W1309: f-string-without-interpolation
+      # E1102: not-callable
+      # E1136: unsubscriptable-object
+      # References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
+      - name: Check Qlib with pylint
+        run: |
+          pip install --upgrade pip
+          pip install pylint
+          pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
       - name: Test data downloads
         run: |
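The lint gate added above can be reproduced locally before pushing. A minimal sketch, assuming it is run from the repository root with whatever pylint version pip currently resolves (the disable list and init-hook are copied verbatim from the step above):

    pip install --upgrade pip
    pip install pylint
    # the --init-hook raises astroid's inference limit so pylint does not
    # give up on qlib's heavily overloaded expression classes
    pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"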
@@ -30,8 +30,8 @@ def init(default_conf="client", **kwargs):
     When using the recorder, skip_if_reg can set to True to avoid loss of recorder.

     """
-    from .config import C
-    from .data.cache import H
+    from .config import C  # pylint: disable=C0415
+    from .data.cache import H  # pylint: disable=C0415

     # FIXME: this logger ignored the level in config
     logger = get_module_logger("Initialization", level=logging.INFO)
@@ -85,7 +85,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
     mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
     # If the provider uri looks like this 172.23.233.89//data/csdesign'
     # It will be a nfs path. The client provider will be used
-    if not auto_mount:
+    if not auto_mount:  # pylint: disable=R1702
         if not Path(mount_path).exists():
             raise FileNotFoundError(
                 f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
@@ -139,8 +139,10 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
         if not _is_mount:
             try:
                 Path(mount_path).mkdir(parents=True, exist_ok=True)
-            except Exception:
-                raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
+            except Exception as e:
+                raise OSError(
+                    f"Failed to create directory {mount_path}, please create {mount_path} manually!"
+                ) from e

             # check nfs-common
             command_res = os.popen("dpkg -l | grep nfs-common")
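The pattern in this hunk — catching an exception and re-raising a different one with `from` — recurs throughout this commit; it satisfies pylint's raise-missing-from (W0707) and keeps the original error reachable as `__cause__`. A minimal, self-contained sketch of the idiom (illustrative names, not qlib code):

    from pathlib import Path

    def ensure_dir(path: str) -> None:
        try:
            Path(path).mkdir(parents=True, exist_ok=True)
        except OSError as e:
            # "from e" chains the exceptions: the traceback shows the
            # original OSError as the cause of the RuntimeError
            raise RuntimeError(f"Failed to create directory {path}") from e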
@@ -171,8 +171,8 @@ get_strategy_executor(
    # NOTE:
    # - for avoiding recursive import
    # - typing annotations is not reliable
-    from ..strategy.base import BaseStrategy
-    from .executor import BaseExecutor
+    from ..strategy.base import BaseStrategy  # pylint: disable=C0415
+    from .executor import BaseExecutor  # pylint: disable=C0415

    trade_account = create_account_instance(
        start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
@@ -2,11 +2,11 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 import copy
-from typing import Dict, List, Tuple, TYPE_CHECKING
+from typing import Dict, List, Tuple
 from qlib.utils import init_instance_by_config
 import pandas as pd

-from .position import BasePosition, InfPosition, Position
+from .position import BasePosition
 from .report import PortfolioMetrics, Indicator
 from .decision import BaseTradeDecision, Order
 from .exchange import Exchange
@@ -7,19 +7,18 @@ from qlib.data.data import Cal
 from qlib.utils.time import concat_date_time, epsilon_change
 from qlib.log import get_module_logger

+from typing import ClassVar, Optional, Union, List, Tuple
+
 # try to fix circular imports when enabling type hints
-from typing import Callable, TYPE_CHECKING
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from qlib.strategy.base import BaseStrategy
     from qlib.backtest.exchange import Exchange
 from qlib.backtest.utils import TradeCalendarManager
 import warnings
-import numpy as np
 import pandas as pd
+import numpy as np
-from dataclasses import dataclass, field
-from typing import ClassVar, Optional, Union, List, Set, Tuple
+from dataclasses import dataclass


 class OrderDir(IntEnum):
@@ -418,7 +417,7 @@ class BaseTradeDecision:
             return kwargs["default_value"]
         else:
             # Default to get full index
-            raise NotImplementedError(f"The decision didn't provide an index range")
+            raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError

     # clip index
     if getattr(self, "total_step", None) is not None:
@@ -3,13 +3,13 @@
 from __future__ import annotations
 from collections import defaultdict
 from typing import TYPE_CHECKING
+from typing import List, Tuple, Union

 if TYPE_CHECKING:
     from .account import Account

 from qlib.backtest.position import BasePosition, Position
 import random
-from typing import List, Tuple, Union
 import numpy as np
 import pandas as pd
@@ -18,7 +18,7 @@ from ..config import C
 from ..constant import REG_CN
 from ..log import get_module_logger
 from .decision import Order, OrderDir, OrderHelper
-from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote
+from .high_performance_ds import BaseQuote, NumpyQuote


 class Exchange:
@@ -1,22 +1,18 @@
-from abc import abstractclassmethod, abstractmethod
+from abc import abstractmethod
 import copy
 from qlib.backtest.position import BasePosition
 from qlib.log import get_module_logger
 from types import GeneratorType
 from qlib.backtest.account import Account
 import warnings
 import pandas as pd
 from typing import List, Tuple, Union
 from collections import defaultdict

 from qlib.backtest.report import Indicator

-from .decision import EmptyTradeDecision, Order, BaseTradeDecision
+from .decision import Order, BaseTradeDecision
 from .exchange import Exchange
 from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx

 from ..utils import init_instance_by_config
 from ..utils.time import Freq
 from ..strategy.base import BaseStrategy

@@ -193,7 +189,8 @@ class BaseExecutor:
             pass
         return return_value.get("execute_result")

-    @abstractclassmethod
+    @classmethod
+    @abstractmethod
     def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
         """
         Please refer to the doc of collect_data
@@ -453,7 +450,6 @@ class NestedExecutor(BaseExecutor):
         inner_exe_res :
             the execution result of inner task
         """
-        pass

     def get_all_executors(self):
         """get all executors, including self and inner_executor.get_all_executors()"""
@@ -2,8 +2,6 @@
 # Licensed under the MIT License.
-
-
 import copy
 import pathlib
 from typing import Dict, List, Union

 import pandas as pd
@@ -538,7 +536,7 @@ class InfPosition(BasePosition):
     def get_stock_amount_dict(self) -> Dict:
         raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")

-    def get_stock_weight_dict(self, only_stock: bool) -> Dict:
+    def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
         raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")

     def add_count_all(self, bar):
@@ -10,11 +10,8 @@ import numpy as np
 import pandas as pd

-from qlib.backtest.exchange import Exchange
-from .decision import IdxTradeRange
 from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
-from qlib.backtest.utils import TradeCalendarManager
-from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
+from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
 from ..data import D
 from ..tests.config import CSI300_BENCH
 from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
 import qlib.utils.index_data as idd
@@ -388,13 +388,11 @@ class QlibConfig(Config):
         default_conf : str
             the default config template chosen by user: "server", "client"
         """
-        from .utils import set_log_with_config, get_module_logger, can_use_cache
+        from .utils import set_log_with_config, get_module_logger, can_use_cache  # pylint: disable=C0415

         self.reset()

-        _logging_config = self.logging_config
-        if "logging_config" in kwargs:
-            _logging_config = kwargs["logging_config"]
+        _logging_config = kwargs.get("logging_config", self.logging_config)

         # set global config
         if _logging_config:
@@ -433,11 +431,11 @@
         )

     def register(self):
-        from .utils import init_instance_by_config
-        from .data.ops import register_all_ops
-        from .data.data import register_all_wrappers
-        from .workflow import R, QlibRecorder
-        from .workflow.utils import experiment_exit_handler
+        from .utils import init_instance_by_config  # pylint: disable=C0415
+        from .data.ops import register_all_ops  # pylint: disable=C0415
+        from .data.data import register_all_wrappers  # pylint: disable=C0415
+        from .workflow import R, QlibRecorder  # pylint: disable=C0415
+        from .workflow.utils import experiment_exit_handler  # pylint: disable=C0415

         register_all_ops(self)
         register_all_wrappers(self)
@@ -454,7 +452,7 @@
         self._registered = True

     def reset_qlib_version(self):
-        import qlib
+        import qlib  # pylint: disable=C0415

         reset_version = self.get("qlib_reset_version", None)
         if reset_version is not None:
@@ -7,8 +7,7 @@ import warnings
 import numpy as np
 import pandas as pd

-from qlib.utils import init_instance_by_config
-from qlib.data.dataset import DatasetH, DataHandler
+from qlib.data.dataset import DatasetH


 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"

 def _to_tensor(x):
     if not isinstance(x, torch.Tensor):
-        return torch.tensor(x, dtype=torch.float, device=device)
+        return torch.tensor(x, dtype=torch.float, device=device)  # pylint: disable=E1101
     return x

@@ -5,9 +5,7 @@ from ...data.dataset.handler import DataHandlerLP
 from ...data.dataset.processor import Processor
 from ...utils import get_callable_kwargs
 from ...data.dataset import processor as processor_module
-from ...log import TimeInspector
 from inspect import getfullargspec
-import copy


 def check_transform_proc(proc_l, fit_start_time, fit_end_time):
@@ -1,9 +1,6 @@
 import numpy as np
 import pandas as pd
-import copy

 from ...log import TimeInspector
-from ...utils.serial import Serializable
 from ...data.dataset.processor import Processor, get_group_columns

@@ -5,12 +5,10 @@
 from __future__ import division
 from __future__ import print_function

-import copy
 import numpy as np
 import pandas as pd
 from scipy.stats import spearmanr, pearsonr

-
 from ..data import D

 from collections import OrderedDict
@@ -243,4 +241,4 @@ def get_rank_ic(a, b):


 def get_normal_ic(a, b):
-    return pearsonr(a, b).correlation
+    return pearsonr(a, b)[0]
@@ -1,24 +1,23 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from copy import deepcopy
-from qlib.data.dataset.utils import init_task_handler
-from qlib.utils.data import deepcopy_basic_type
-from qlib.contrib.torch import data_to_tensor
-from qlib.workflow.task.utils import TimeAdjuster
-from qlib.model.meta.task import MetaTask
-from typing import Dict, List, Union, Text, Tuple
-from qlib.data.dataset.handler import DataHandler
-from qlib.log import get_module_logger
-from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
-from qlib.workflow import R
-from qlib.workflow.task.gen import RollingGen, task_generator
-from joblib import Parallel, delayed
-from qlib.model.meta.dataset import MetaTaskDataset
-from qlib.model.trainer import task_train, TrainerR
-from qlib.data.dataset import DatasetH
-from tqdm.auto import tqdm
 import pandas as pd
 import numpy as np
+from copy import deepcopy
+from joblib import Parallel, delayed  # pylint: disable=E0401
+from typing import Dict, List, Union, Text, Tuple
+from qlib.data.dataset.utils import init_task_handler
+from qlib.data.dataset import DatasetH
+from qlib.contrib.torch import data_to_tensor
+from qlib.model.meta.task import MetaTask
+from qlib.model.meta.dataset import MetaTaskDataset
+from qlib.model.trainer import TrainerR
+from qlib.log import get_module_logger
+from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
+from qlib.utils.data import deepcopy_basic_type
+from qlib.workflow import R
+from qlib.workflow.task.gen import RollingGen, task_generator
+from qlib.workflow.task.utils import TimeAdjuster
+from tqdm.auto import tqdm


 class InternalData:
@@ -1,28 +1,26 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

-from qlib.log import get_module_logger
 import pandas as pd
 import numpy as np
-from qlib.model.meta.task import MetaTask
 import torch
 from torch import nn
 from torch import optim
 from tqdm.auto import tqdm
 import collections
 import copy
-from typing import Union, List, Tuple, Dict
+from typing import Union, List

 from ....data.dataset.weight import Reweighter
 from ....model.meta.dataset import MetaTaskDataset
-from ....model.meta.model import MetaModel, MetaTaskModel
+from ....model.meta.model import MetaTaskModel
 from ....workflow import R

 from .utils import ICLoss
 from .dataset import MetaDatasetDS
-from qlib.contrib.meta.data_selection.net import PredNet
-from qlib.data.dataset.weight import Reweighter

+from qlib.log import get_module_logger
+from qlib.data.dataset.weight import Reweighter
+from qlib.model.meta.task import MetaTask
+from qlib.contrib.meta.data_selection.net import PredNet

 logger = get_module_logger("data selection")

@@ -1,7 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

-import pandas as pd
 import numpy as np
 import torch
 from torch import nn
@@ -1,11 +1,9 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

-import pandas as pd
-import numpy as np
 import torch
 from torch import nn
 from qlib.contrib.torch import data_to_tensor


 class ICLoss(nn.Module):
@@ -101,7 +101,7 @@ class LGBModel(ModelFT, LightGBMFInt):
             verbose level
         """
         # Based on existing model and finetune by train more rounds
-        dtrain, _ = self._prepare_data(dataset, reweighter)
+        dtrain, _ = self._prepare_data(dataset, reweighter)  # pylint: disable=W0632
         if dtrain.empty:
             raise ValueError("Empty data from dataset, please check your dataset config.")
         self.model = lgb.train(
@@ -58,7 +58,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
         """
         Test the signal in high frequency test set
         """
-        if self.model == None:
+        if self.model is None:
             raise ValueError("Model hasn't been trained yet")
         df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
         df_test.dropna(inplace=True)
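`== None` is flagged by pylint (singleton-comparison) because `__eq__` can be overloaded, so equality against None is not guaranteed to return a plain bool; `is None` always performs a cheap identity check. A hedged illustration of the difference — numpy is used here only to show the overloading, it is not implied by the hunk above:

    import numpy as np

    arr = np.array([1, 2, 3])
    print(arr == None)  # elementwise: [False False False] -- an array, not a bool
    print(arr is None)  # False -- a single, unambiguous answer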
@@ -1,12 +1,10 @@
 # Copyright (c) Microsoft Corporation.
 import os
-from pdb import set_trace
-from torch.utils.data import Dataset, DataLoader

 import copy
 from typing import Text, Union

 import math
 import numpy as np
 import pandas as pd
 import torch
@@ -182,11 +180,11 @@ class ADARNN(Model):
                 continue

             total_loss = torch.zeros(1).cuda()
-            for i in range(len(index)):
-                feature_s = list_feat[index[i][0]]
-                feature_t = list_feat[index[i][1]]
-                label_reg_s = list_label[index[i][0]]
-                label_reg_t = list_label[index[i][1]]
+            for i, n in enumerate(index):
+                feature_s = list_feat[n[0]]
+                feature_t = list_feat[n[1]]
+                label_reg_s = list_label[n[0]]
+                label_reg_t = list_label[n[1]]
                 feature_all = torch.cat((feature_s, feature_t), 0)

                 if epoch < self.pre_epoch:
@@ -410,7 +408,7 @@ class AdaRNN(nn.Module):
             in_size = hidden
         self.features = nn.Sequential(*features)

-        if use_bottleneck == True:  # finance
+        if use_bottleneck is True:  # finance
             self.bottleneck = nn.Sequential(
                 nn.Linear(n_hiddens[-1], bottleneck_width),
                 nn.Linear(bottleneck_width, bottleneck_width),
@@ -449,7 +447,7 @@
     def forward_pre_train(self, x, len_win=0):
         out = self.gru_features(x)
         fea = out[0]  # [2N,L,H]
-        if self.use_bottleneck == True:
+        if self.use_bottleneck is True:
             fea_bottleneck = self.bottleneck(fea[:, -1, :])
             fc_out = self.fc(fea_bottleneck).squeeze()
         else:
@@ -458,8 +456,8 @@
         out_list_all, out_weight_list = out[1], out[2]
         out_list_s, out_list_t = self.get_features(out_list_all)
         loss_transfer = torch.zeros((1,)).cuda()
-        for i in range(len(out_list_s)):
-            criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
+        for i, n in enumerate(out_list_s):
+            criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
             h_start = 0
             for j in range(h_start, self.len_seq, 1):
                 i_start = j - len_win if j - len_win >= 0 else 0
@@ -471,7 +469,7 @@
                     else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
                 )
                 loss_transfer = loss_transfer + weight * criterion_transder.compute(
-                    out_list_s[i][:, j, :], out_list_t[i][:, k, :]
+                    n[:, j, :], out_list_t[i][:, k, :]
                 )
         return fc_out, loss_transfer, out_weight_list

@@ -484,7 +482,7 @@
             out, _ = self.features[i](x_input.float())
             x_input = out
             out_lis.append(out)
-            if self.model_type == "AdaRNN" and predict == False:
+            if self.model_type == "AdaRNN" and predict is False:
                 out_gate = self.process_gate_weight(x_input, i)
                 out_weight_list.append(out_gate)
         return out, out_lis, out_weight_list
@@ -524,10 +522,10 @@
         else:
             weight = weight_mat
         dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
-        for i in range(len(out_list_s)):
-            criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
+        for i, n in enumerate(out_list_s):
+            criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
             for j in range(self.len_seq):
-                loss_trans = criterion_transder.compute(out_list_s[i][:, j, :], out_list_t[i][:, j, :])
+                loss_trans = criterion_transder.compute(n[:, j, :], out_list_t[i][:, j, :])
                 loss_transfer = loss_transfer + weight[i, j] * loss_trans
                 dist_mat[i, j] = loss_trans
         return fc_out, loss_transfer, dist_mat, weight
@@ -546,7 +544,7 @@
     def predict(self, x):
         out = self.gru_features(x, predict=True)
         fea = out[0]
-        if self.use_bottleneck == True:
+        if self.use_bottleneck is True:
             fea_bottleneck = self.bottleneck(fea[:, -1, :])
             fc_out = self.fc(fea_bottleneck).squeeze()
         else:
@@ -572,12 +570,12 @@ class TransferLoss:
         Returns:
             [tensor] -- transfer loss
         """
-        if self.loss_type == "mmd_lin" or self.loss_type == "mmd":
+        if self.loss_type in ("mmd_lin", "mmd"):
             mmdloss = MMD_loss(kernel_type="linear")
             loss = mmdloss(X, Y)
         elif self.loss_type == "coral":
             loss = CORAL(X, Y)
-        elif self.loss_type == "cosine" or self.loss_type == "cos":
+        elif self.loss_type in ("cosine", "cos"):
             loss = 1 - cosine(X, Y)
         elif self.loss_type == "kl":
             loss = kl_div(X, Y)
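The `a == x or a == y` → `a in (x, y)` rewrites in this hunk, and in the many metric_fn hunks below, address pylint's consider-using-in (R1714): tuple membership evaluates the subject once and reads as a single condition. A small sketch with hypothetical values:

    loss_type = "cos"

    # the two forms are equivalent; the commit standardizes on the second
    assert (loss_type == "cosine" or loss_type == "cos") == (loss_type in ("cosine", "cos"))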
@@ -20,7 +20,6 @@ from qlib.contrib.model.pytorch_lstm import LSTMModel
 from qlib.contrib.model.pytorch_utils import count_parameters
 from qlib.data.dataset import DatasetH
 from qlib.data.dataset.handler import DataHandlerLP
-from qlib.data.dataset.processor import CSRankNorm
 from qlib.log import get_module_logger
 from qlib.model.base import Model
 from qlib.utils import get_or_create_path
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -150,7 +149,7 @@ class ALSTM(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -312,8 +311,8 @@ class ALSTMModel(nn.Module):
     def _build_model(self):
         try:
             klass = getattr(nn, self.rnn_type.upper())
-        except:
-            raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
+        except Exception as e:
+            raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
         self.net = nn.Sequential()
         self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
         self.net.add_module("act", nn.Tanh())
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -20,7 +19,7 @@ from torch.utils.data import DataLoader

 from .pytorch_utils import count_parameters
 from ...model.base import Model
-from ...data.dataset import DatasetH, TSDatasetH
+from ...data.dataset import DatasetH
 from ...data.dataset.handler import DataHandlerLP
 from ...model.utils import ConcatDataset
 from ...data.dataset.weight import Reweighter
@@ -160,7 +159,7 @@ class ALSTM(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -320,8 +319,8 @@ class ALSTMModel(nn.Module):
     def _build_model(self):
         try:
             klass = getattr(nn, self.rnn_type.upper())
-        except:
-            raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
+        except Exception as e:
+            raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
         self.net = nn.Sequential()
         self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
         self.net.add_module("act", nn.Tanh())
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -158,7 +157,7 @@ class GATs(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -263,7 +262,9 @@
             pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))

             model_dict = self.GAT_model.state_dict()
-            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
+            pretrained_dict = {
+                k: v for k, v in pretrained_model.state_dict().items() if k in model_dict
+            }  # pylint: disable=E1135
             model_dict.update(pretrained_dict)
             self.GAT_model.load_state_dict(model_dict)
             self.logger.info("Loading pretrained model Done...")
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 import copy
@@ -19,7 +18,6 @@ from torch.utils.data import Sampler

 from .pytorch_utils import count_parameters
 from ...model.base import Model
-from ...data.dataset import DatasetH
 from ...data.dataset.handler import DataHandlerLP
 from ...contrib.model.pytorch_lstm import LSTMModel
 from ...contrib.model.pytorch_gru import GRUModel
@@ -178,7 +176,7 @@

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -279,7 +277,9 @@
             pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))

             model_dict = self.GAT_model.state_dict()
-            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
+            pretrained_dict = {
+                k: v for k, v in pretrained_model.state_dict().items() if k in model_dict
+            }  # pylint: disable=E1135
             model_dict.update(pretrained_dict)
             self.GAT_model.load_state_dict(model_dict)
             self.logger.info("Loading pretrained model Done...")
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -150,7 +149,7 @@ class GRU(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 import copy
@@ -19,7 +18,6 @@ from torch.utils.data import DataLoader

 from .pytorch_utils import count_parameters
 from ...model.base import Model
-from ...data.dataset import DatasetH, TSDatasetH
 from ...data.dataset.handler import DataHandlerLP
 from ...model.utils import ConcatDataset
 from ...data.dataset.weight import Reweighter
@@ -159,7 +157,7 @@

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -17,11 +16,9 @@ from ...log import get_module_logger
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.utils.data import DataLoader

 from .pytorch_utils import count_parameters
 from ...model.base import Model
-from ...data.dataset import DatasetH, TSDatasetH
+from ...data.dataset import DatasetH
 from ...data.dataset.handler import DataHandlerLP
 from torch.nn.modules.container import ModuleList

@@ -102,7 +99,7 @@ class LocalformerModel(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 import copy
@@ -18,9 +17,8 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader

 from .pytorch_utils import count_parameters
 from ...model.base import Model
-from ...data.dataset import DatasetH, TSDatasetH
+from ...data.dataset import DatasetH
 from ...data.dataset.handler import DataHandlerLP
 from torch.nn.modules.container import ModuleList

@@ -101,7 +99,7 @@ class LocalformerModel(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -146,7 +145,7 @@ class LSTM(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 import copy
@@ -18,7 +17,6 @@ import torch.optim as optim
 from torch.utils.data import DataLoader

 from ...model.base import Model
-from ...data.dataset import DatasetH, TSDatasetH
 from ...data.dataset.handler import DataHandlerLP
 from ...model.utils import ConcatDataset
 from ...data.dataset.weight import Reweighter
@@ -155,7 +153,7 @@

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -328,6 +328,7 @@ class Net(nn.Module):
         dnn_layers = []
         drop_input = nn.Dropout(0.05)
         dnn_layers.append(drop_input)
+        hidden_units = None
         for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
             fc = nn.Linear(_input_dim, hidden_units)
             activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
@@ -338,7 +339,7 @@ class Net(nn.Module):
         dnn_layers.append(drop_input)
         fc = nn.Linear(hidden_units, output_dim)
         dnn_layers.append(fc)
-        # optimizer
+        # optimizer # pylint: disable=W0631
         self.dnn_layers = nn.ModuleList(dnn_layers)
         self._weight_init()

@@ -4,7 +4,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -435,7 +434,7 @@ class SFM(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -3,7 +3,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -378,7 +377,7 @@ class TabnetModel(Model):

     def metric_fn(self, pred, label):
         mask = torch.isfinite(label)
-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])
         raise ValueError("unknown metric `%s`" % self.metric)

@@ -15,7 +15,6 @@ from ...log import get_module_logger
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.nn.utils import weight_norm

 from .pytorch_utils import count_parameters
 from ...model.base import Model
@@ -158,7 +157,7 @@ class TCN(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -158,7 +158,7 @@ class TCN(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -5,20 +5,12 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 import copy
 import random
-from sklearn.metrics import roc_auc_score, mean_squared_error
-import logging
-from ...utils import (
-    unpack_archive_with_buffer,
-    save_multiple_parts_file,
-    get_or_create_path,
-    drop_nan_by_y_index,
-)
-from ...log import get_module_logger, TimeInspector
+from ...utils import get_or_create_path
+from ...log import get_module_logger

 import torch
 import torch.nn as nn
@@ -263,7 +255,7 @@ class TCTS(Model):
         x_valid, y_valid = df_valid["feature"], df_valid["label"]
         x_test, y_test = df_test["feature"], df_test["label"]

-        if save_path == None:
+        if save_path is None:
             save_path = get_or_create_path(save_path)
         best_loss = np.inf
         while best_loss > self.lowest_valid_performance:
@@ -6,10 +6,8 @@ import os
 import copy
 import math
 import json
 import collections
 import numpy as np
 import pandas as pd
-import seaborn as sns
-import matplotlib.pyplot as plt

 import torch
@@ -24,7 +22,6 @@ except ImportError:

 from tqdm import tqdm

-from qlib.utils import get_or_create_path
 from qlib.constant import EPS
 from qlib.log import get_module_logger
 from qlib.model.base import Model
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 from typing import Text, Union
@@ -17,11 +16,9 @@ from ...log import get_module_logger
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.utils.data import DataLoader

 from .pytorch_utils import count_parameters
 from ...model.base import Model
-from ...data.dataset import DatasetH, TSDatasetH
+from ...data.dataset import DatasetH
 from ...data.dataset.handler import DataHandlerLP

 # qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml
@@ -101,7 +98,7 @@ class TransformerModel(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -5,7 +5,6 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import numpy as np
 import pandas as pd
 import copy
@@ -18,9 +17,8 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import DataLoader

 from .pytorch_utils import count_parameters
 from ...model.base import Model
-from ...data.dataset import DatasetH, TSDatasetH
+from ...data.dataset import DatasetH
 from ...data.dataset.handler import DataHandlerLP

@@ -98,7 +96,7 @@ class TransformerModel(Model):

         mask = torch.isfinite(label)

-        if self.metric == "" or self.metric == "loss":
+        if self.metric in ("", "loss"):
             return -self.loss_fn(pred[mask], label[mask])

         raise ValueError("unknown metric `%s`" % self.metric)
@@ -26,11 +26,11 @@ def count_parameters(models_or_parameters, unit="m"):
     else:
         counts = sum(v.numel() for v in models_or_parameters)
     unit = unit.lower()
-    if unit == "kb" or unit == "k":
+    if unit in ("kb", "k"):
         counts /= 2 ** 10
-    elif unit == "mb" or unit == "m":
+    elif unit in ("mb", "m"):
         counts /= 2 ** 20
-    elif unit == "gb" or unit == "g":
+    elif unit in ("gb", "g"):
         counts /= 2 ** 30
     elif unit is not None:
         raise ValueError("Unknown unit: {:}".format(unit))
@@ -1,6 +1,5 @@
 # MIT License
 # Copyright (c) 2018 CMU Locus Lab
-
 import torch
 import torch.nn as nn
 from torch.nn.utils import weight_norm

@@ -1,3 +1,5 @@
+# pylint: skip-file
+
 '''
 TODO:

@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 import yaml
 import pathlib
 import pandas as pd
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 import random
 import pandas as pd
 from ...data import D
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 import fire
 import pandas as pd
 import pathlib
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 import logging

 from ...log import get_module_logger
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 import pathlib
 import pickle
 import yaml
@@ -1,12 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from pathlib import Path
 import numpy as np
 import pandas as pd
-from datetime import datetime

-import qlib
-from qlib.data import D
 from qlib.data.cache import H
 from qlib.data.data import Cal
 from qlib.data.ops import ElemOperator
@@ -34,7 +34,7 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
         {
             "Group%d"
             % (i + 1): pred_label_drop.groupby(level="datetime")["label"].apply(
-                lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean()
+                lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean()  # pylint: disable=W0640
             )
             for i in range(N)
         }
@@ -282,8 +282,10 @@ class SubplotsGraph:
         if self._subplots_kwargs is None:
             self._init_subplots_kwargs()

-        self.__cols = self._subplots_kwargs.get("cols", 2)
-        self.__rows = self._subplots_kwargs.get("rows", math.ceil(len(self._df.columns) / self.__cols))
+        self.__cols = self._subplots_kwargs.get("cols", 2)  # pylint: disable=W0238
+        self.__rows = self._subplots_kwargs.get(  # pylint: disable=W0238
+            "rows", math.ceil(len(self._df.columns) / self.__cols)
+        )

         self._sub_graph_data = sub_graph_data
         if self._sub_graph_data is None:
@@ -10,4 +10,3 @@ class BaseOptimizer(abc.ABC):
     @abc.abstractmethod
     def __call__(self, *args, **kwargs) -> object:
         """Generate a optimized portfolio allocation"""
-        pass
@@ -3,7 +3,6 @@

 import numpy as np
 import cvxpy as cp
-import pandas as pd

 from typing import Union, Optional, Dict, Any, List
@@ -156,7 +155,7 @@ class EnhancedIndexingOptimizer(BaseOptimizer):

         # factor deviation
         if self.f_dev is not None:
-            cons.extend([v >= -self.f_dev, v <= self.f_dev])
+            cons.extend([v >= -self.f_dev, v <= self.f_dev])  # pylint: disable=E1130

         # total turnover constraint
         t_cons = []
@@ -6,7 +6,6 @@ This order generator is for strategies based on WeightStrategyBase
 """
 from ...backtest.position import Position
 from ...backtest.exchange import Exchange
-from ...backtest.decision import BaseTradeDecision, TradeDecisionWO

 import pandas as pd
 import copy
@@ -3,7 +3,6 @@
-import os
 import copy
 import warnings
 import cvxpy as cp
 import numpy as np
 import pandas as pd

@@ -15,11 +14,10 @@ from qlib.model.base import BaseModel
 from qlib.strategy.base import BaseStrategy
 from qlib.backtest.position import Position
 from qlib.backtest.signal import Signal, create_signal_from
-from qlib.backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
+from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
 from qlib.log import get_module_logger
 from qlib.utils import get_pre_trading_date, load_dataset
 from qlib.utils.resam import resam_ts_data
-from qlib.contrib.strategy.order_generator import OrderGenWInteract, OrderGenWOInteract
+from qlib.contrib.strategy.order_generator import OrderGenWOInteract
 from qlib.contrib.strategy.optimizer import EnhancedIndexingOptimizer

@@ -0,0 +1 @@
+# pylint: skip-file
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 import yaml
 import copy
 import os
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 # coding=utf-8

 import argparse
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 import os
 import json
 import logging
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 from hyperopt import hp

@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+# pylint: skip-file
+
 import os
 import yaml
 import json
@@ -6,7 +6,6 @@ from __future__ import division
 from __future__ import print_function

 import abc
 import pandas as pd

-from ..log import get_module_logger
@@ -21,107 +20,107 @@ class Expression(abc.ABC):
         return str(self)

     def __gt__(self, other):
-        from .ops import Gt
+        from .ops import Gt  # pylint: disable=C0415

         return Gt(self, other)

     def __ge__(self, other):
-        from .ops import Ge
+        from .ops import Ge  # pylint: disable=C0415

         return Ge(self, other)

     def __lt__(self, other):
-        from .ops import Lt
+        from .ops import Lt  # pylint: disable=C0415

         return Lt(self, other)

     def __le__(self, other):
-        from .ops import Le
+        from .ops import Le  # pylint: disable=C0415

         return Le(self, other)

     def __eq__(self, other):
-        from .ops import Eq
+        from .ops import Eq  # pylint: disable=C0415

         return Eq(self, other)

     def __ne__(self, other):
-        from .ops import Ne
+        from .ops import Ne  # pylint: disable=C0415

         return Ne(self, other)

     def __add__(self, other):
-        from .ops import Add
+        from .ops import Add  # pylint: disable=C0415

         return Add(self, other)

     def __radd__(self, other):
-        from .ops import Add
+        from .ops import Add  # pylint: disable=C0415

         return Add(other, self)

     def __sub__(self, other):
-        from .ops import Sub
+        from .ops import Sub  # pylint: disable=C0415

         return Sub(self, other)

     def __rsub__(self, other):
-        from .ops import Sub
+        from .ops import Sub  # pylint: disable=C0415

         return Sub(other, self)

     def __mul__(self, other):
-        from .ops import Mul
+        from .ops import Mul  # pylint: disable=C0415

         return Mul(self, other)

     def __rmul__(self, other):
-        from .ops import Mul
+        from .ops import Mul  # pylint: disable=C0415

         return Mul(self, other)

     def __div__(self, other):
-        from .ops import Div
+        from .ops import Div  # pylint: disable=C0415

         return Div(self, other)

     def __rdiv__(self, other):
-        from .ops import Div
+        from .ops import Div  # pylint: disable=C0415

         return Div(other, self)

     def __truediv__(self, other):
-        from .ops import Div
+        from .ops import Div  # pylint: disable=C0415

         return Div(self, other)

     def __rtruediv__(self, other):
-        from .ops import Div
+        from .ops import Div  # pylint: disable=C0415

         return Div(other, self)

     def __pow__(self, other):
-        from .ops import Power
+        from .ops import Power  # pylint: disable=C0415

         return Power(self, other)

     def __and__(self, other):
-        from .ops import And
+        from .ops import And  # pylint: disable=C0415

         return And(self, other)

     def __rand__(self, other):
-        from .ops import And
+        from .ops import And  # pylint: disable=C0415

         return And(other, self)

     def __or__(self, other):
-        from .ops import Or
+        from .ops import Or  # pylint: disable=C0415

         return Or(self, other)

     def __ror__(self, other):
-        from .ops import Or
+        from .ops import Or  # pylint: disable=C0415

         return Or(other, self)

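Each `# pylint: disable=C0415` above marks an intentionally deferred import: qlib.data.base and qlib.data.ops import each other (the diff's own "To avoid recursive import" comments say as much), so the import has to live inside the method to break the cycle at module-load time, and the comment records that the placement is deliberate rather than an oversight. A stripped-down illustration with two hypothetical modules, not qlib's real layout:

    # a.py -- imported by b.py at module level
    def combine(x, y):
        # a module-level "from b import Pair" would fail while b is still
        # importing a; deferring it into the function delays the lookup
        # until first call, when both modules are fully initialized
        from b import Pair  # pylint: disable=C0415
        return Pair(x, y)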
@@ -144,7 +143,7 @@
         pd.Series
             feature series: The index of the series is the calendar index
         """
-        from .cache import H
+        from .cache import H  # pylint: disable=C0415

         # cache
         args = str(self), instrument, start_index, end_index, freq
@@ -215,7 +214,7 @@ class Feature(Expression):

     def _load_internal(self, instrument, start_index, end_index, freq):
         # load
-        from .data import FeatureD
+        from .data import FeatureD  # pylint: disable=C0415

         return FeatureD.feature(instrument, str(self), start_index, end_index, freq)

@@ -232,5 +231,3 @@
     This kind of feature will use operator for feature
     construction on the fly.
     """
-
-    pass
@@ -33,8 +33,7 @@ from ..utils import (

 from ..log import get_module_logger
 from .base import Feature
-
-from .ops import Operators
+from .ops import Operators  # pylint: disable=W0611


 class QlibCacheException(RuntimeError):
@@ -229,8 +228,8 @@ class CacheUtils:
             try:
                 d["meta"]["last_visit"] = str(time.time())
                 d["meta"]["visits"] = d["meta"]["visits"] + 1
-            except KeyError:
-                raise KeyError("Unknown meta keyword")
+            except KeyError as key_e:
+                raise KeyError("Unknown meta keyword") from key_e
             pickle.dump(d, f, protocol=C.dump_protocol_version)
         except Exception as e:
             get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}")
@@ -239,7 +238,7 @@
     def acquire(lock, lock_name):
         try:
             lock.acquire()
-        except redis_lock.AlreadyAcquired:
+        except redis_lock.AlreadyAcquired as lock_acquired:
             raise QlibCacheException(
                 f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
                 You can use the following command to clear your redis keys and rerun your commands:
@@ -249,7 +248,7 @@
                 > quit
                 If the issue is not resolved, use "keys *" to find if multiple keys exist. If so, try using "flushall" to clear all the keys.
                 """
-            )
+            ) from lock_acquired

     @staticmethod
     @contextlib.contextmanager
@@ -507,7 +506,7 @@ class DiskExpressionCache(ExpressionCache):
         _instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())
         cache_path = _instrument_dir.joinpath(_cache_uri)
         # get calendar
-        from .data import Cal
+        from .data import Cal  # pylint: disable=C0415

         _calendar = Cal.calendar(freq=freq)

@@ -599,7 +598,7 @@
         last_update_time = d["info"]["last_update"]

         # get newest calendar
-        from .data import Cal, ExpressionD
+        from .data import Cal, ExpressionD  # pylint: disable=C0415

         whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq)
         # calendar since last updated.
@@ -753,7 +752,7 @@ class DiskDatasetCache(DatasetCache):
         if disk_cache == 0:
             # In this case, server only checks the expression cache.
             # The client will load the cache data by itself.
-            from .data import LocalDatasetProvider
+            from .data import LocalDatasetProvider  # pylint: disable=C0415

             LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)
             return ""
@@ -895,7 +894,7 @@
         :return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
         """
         # get calendar
-        from .data import Cal
+        from .data import Cal  # pylint: disable=C0415

         cache_path = Path(cache_path)
         _calendar = Cal.calendar(freq=freq)
@@ -970,14 +969,14 @@
         index_data = im.get_index()

         self.logger.debug("Updating dataset: {}".format(d))
-        from .data import Inst
+        from .data import Inst  # pylint: disable=C0415

         if Inst.get_inst_type(instruments) == Inst.DICT:
             self.logger.info(f"The file {cache_uri} has dict cache. Skip updating")
             return 1

         # get newest calendar
-        from .data import Cal
+        from .data import Cal  # pylint: disable=C0415

         whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq)
         # The calendar since last updated
@@ -994,7 +993,7 @@
         current_index = len(whole_calendar) - len(new_calendar) + 1

         # To avoid recursive import
-        from .data import ExpressionD
+        from .data import ExpressionD  # pylint: disable=C0415

         # The existing data length
         lft_etd = rght_etd = 0
@@ -5,17 +5,13 @@
 from __future__ import division
 from __future__ import print_function

-import os
 import re
 import abc
-import time
 import copy
 import queue
 import bisect
 import numpy as np
 import pandas as pd
-from multiprocessing import Pool
-from typing import Iterable, Union
+from typing import List, Union

 # For supporting multiprocessing in outer code, joblib is used
@@ -23,13 +19,10 @@ from joblib import delayed

 from .cache import H
 from ..config import C
 from .base import Feature
-from .ops import Operators
 from .inst_processor import InstProcessor
-
 from ..log import get_module_logger
 from ..utils.time import Freq
-from .cache import DiskDatasetCache, DiskExpressionCache
+from .cache import DiskDatasetCache
 from ..utils import (
     Wrapper,
     init_instance_by_config,
@@ -43,6 +36,7 @@ from ..utils import (
     time_to_slc_point,
 )
 from ..utils.paral import ParallelExt
+from .ops import Operators  # pylint: disable=W0611


 class ProviderBackendMixin:
@@ -144,10 +138,10 @@ class CalendarProvider(abc.ABC):
         if start_time not in calendar_index:
             try:
                 start_time = calendar[bisect.bisect_left(calendar, start_time)]
-            except IndexError:
+            except IndexError as index_e:
                 raise IndexError(
                     "`start_time` uses a future date, if you want to get future trading days, you can use: `future=True`"
-                )
+                ) from index_e
             start_index = calendar_index[start_time]
         if end_time not in calendar_index:
             end_time = calendar[bisect.bisect_right(calendar, end_time) - 1]
@@ -246,7 +240,7 @@
         """
         if isinstance(market, list):
             return market
-        from .filter import SeriesDFilter
+        from .filter import SeriesDFilter  # pylint: disable=C0415

         if filter_pipe is None:
             filter_pipe = []
@ -672,7 +666,7 @@ class LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin):
|
|||
# filter
|
||||
filter_pipe = instruments["filter_pipe"]
|
||||
for filter_config in filter_pipe:
|
||||
from . import filter as F
|
||||
from . import filter as F # pylint: disable=C0415
|
||||
|
||||
filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config)
|
||||
_instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq)
|
||||
|
@ -1003,8 +997,8 @@ class ClientDatasetProvider(DatasetProvider):
|
|||
if return_uri:
|
||||
return df, feature_uri
|
||||
return df
|
||||
except AttributeError:
|
||||
raise IOError("Unable to fetch instruments from remote server!")
|
||||
except AttributeError as attribute_e:
|
||||
raise IOError("Unable to fetch instruments from remote server!") from attribute_e
|
||||
|
||||
|
||||
class BaseProvider:
|
||||
|
@ -1110,7 +1104,7 @@ class ClientProvider(BaseProvider):
|
|||
|
||||
return isinstance(instance, cls)
|
||||
|
||||
from .client import Client
|
||||
from .client import Client # pylint: disable=C0415
|
||||
|
||||
self.client = Client(C.flask_server, C.flask_port)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
|
|

@@ -52,7 +52,6 @@ class Dataset(Serializable):

- User prepare data for model based on previous status.
"""
pass

def prepare(self, **kwargs) -> object:
"""

@@ -68,7 +67,6 @@ class Dataset(Serializable):
object:
return the object
"""
pass


class DatasetH(Dataset):

@@ -348,7 +346,7 @@ class TSDataSampler:
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.data_index = self.data_index[np.where(self.flt_data)[0]]
self.idx_map = self.idx_map2arr(self.idx_map)

self.start_idx, self.end_idx = self.data_index.slice_locs(
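
A note on the singleton comparison in this hunk: for a NumPy boolean array, the textbook C0121 rewrite `arr is True` would be wrong, since identity against the Python `True` singleton is always False and would silently select nothing. Passing the mask itself is both lint-clean and correct, which is the form used above:

    import numpy as np

    mask = np.array([True, False, True])
    print(np.where(mask == True)[0])  # [0 2], works but trips pylint C0121
    print(mask is True)               # False: identity check, never what is meant here
    print(np.where(mask)[0])          # [0 2], idiomatic and lint-clean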

@@ -2,24 +2,16 @@
# Licensed under the MIT License.

# coding=utf-8
import abc
import bisect
import logging
import warnings
from inspect import getfullargspec
from typing import Callable, Union, Tuple, List, Iterator, Optional

import pandas as pd
import numpy as np

from ...log import get_module_logger, TimeInspector
from ...data import D
from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils import init_instance_by_config
from ...utils.serial import Serializable
from .utils import fetch_df_by_index, fetch_df_by_col
from ...utils import lazy_sort_index
from pathlib import Path
from .loader import DataLoader

from . import processor as processor_module

@@ -228,7 +220,7 @@ class DataHandler(Serializable):
proc_func: Callable = None,
):
# This method is extracted for sharing in subclasses
from .storage import BaseHandlerStorage
from .storage import BaseHandlerStorage # pylint: disable=C0415

# The following conflicts may occur
# - Does ["20200101", "20210101"] mean selecting this slice or these two days?

@@ -627,7 +619,6 @@ class DataHandlerLP(DataHandler):
-------
pd.DataFrame:
"""
from .storage import BaseHandlerStorage

return self._fetch_data(
data_storage=self._get_df_by_key(data_key),

@@ -51,7 +51,6 @@ class DataLoader(abc.ABC):
pd.DataFrame:
data load from the under layer source
"""
pass


class DLWParser(DataLoader):

@@ -129,7 +128,6 @@ class DLWParser(DataLoader):
pd.DataFrame:
the queried dataframe.
"""
pass

def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if self.is_group:

@@ -308,7 +306,7 @@ class DataLoaderDH(DataLoader):
is_group will be used to describe whether the key of handler_config is group

"""
from qlib.data.dataset.handler import DataHandler
from qlib.data.dataset.handler import DataHandler # pylint: disable=C0415

if is_group:
self.handlers = {

@@ -42,7 +42,6 @@ class Processor(Serializable):
processor, i.e. `df`.

"""
pass

@abc.abstractmethod
def __call__(self, df: pd.DataFrame):

@@ -57,7 +56,6 @@ class Processor(Serializable):
df : pd.DataFrame
The raw_df of handler or result from previous processor.
"""
pass

def is_for_infer(self) -> bool:
"""

@@ -201,7 +199,7 @@ class MinMaxNorm(Processor):
self.fit_end_time = fit_end_time
self.fields_group = fields_group

def fit(self, df):
def fit(self, df: pd.DataFrame = None):
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
cols = get_group_columns(df, self.fields_group)
self.min_val = np.nanmin(df[cols].values, axis=0)

@@ -232,7 +230,7 @@ class ZScoreNorm(Processor):
self.fit_end_time = fit_end_time
self.fields_group = fields_group

def fit(self, df):
def fit(self, df: pd.DataFrame = None):
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
cols = get_group_columns(df, self.fields_group)
self.mean_train = np.nanmean(df[cols].values, axis=0)

@@ -272,7 +270,7 @@ class RobustZScoreNorm(Processor):
self.fields_group = fields_group
self.clip_outlier = clip_outlier

def fit(self, df):
def fit(self, df: pd.DataFrame = None):
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
self.cols = get_group_columns(df, self.fields_group)
X = df[self.cols].values

@@ -351,6 +349,6 @@ class HashStockFormat(Processor):
"""Process the storage from df into hashing stock format"""

def __call__(self, df: pd.DataFrame):
from .storage import HasingStockStorage
from .storage import HasingStockStorage # pylint: disable=C0415

return HasingStockStorage.from_df(df)

@@ -2,7 +2,7 @@ import pandas as pd
import numpy as np

from .handler import DataHandler
from typing import Tuple, Union, List, Callable
from typing import Union, List, Callable

from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col

@@ -109,7 +109,7 @@ class HasingStockStorage(BaseHandlerStorage):
stock_selector = selector[self.stock_level]
elif isinstance(selector, (list, str)) and self.stock_level == 0:
stock_selector = selector
elif level == "instrument" or level == self.stock_level:
elif level in ("instrument", self.stock_level):
if isinstance(selector, tuple):
stock_selector = selector[0]
elif isinstance(selector, (list, str)):
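
The `level == "instrument" or level == self.stock_level` change is pylint's R1714 (consider-using-in): two equality checks on the same variable collapse into one membership test, which is equivalent for hashable values and reads as a single condition:

    level = "instrument"
    if level == "instrument" or level == 0:   # before: chained equality checks
        pass
    if level in ("instrument", 0):            # after: one membership test
        pass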

@@ -63,7 +63,7 @@ def fetch_df_by_index(
Data of the given index.
"""
# level = None -> use selector directly
if level == None:
if level is None:
return df.loc(axis=0)[selector]
# Try to get the right index
idx_slc = (selector, slice(None, None))

@@ -75,7 +75,7 @@ def fetch_df_by_index(
return df.loc[
pd.IndexSlice[idx_slc],
]
else:
else: # pylint: disable=W0120
return df
else:
return df.loc[

@@ -84,7 +84,7 @@ def fetch_df_by_index(


def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
from .handler import DataHandler
from .handler import DataHandler # pylint: disable=C0415

if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW:
return df

@@ -136,7 +136,7 @@ def init_task_handler(task: dict) -> Union[DataHandler, None]:
returns
"""
# avoid recursive import
from .handler import DataHandler
from .handler import DataHandler # pylint: disable=C0415

h_conf = task["dataset"]["kwargs"].get("handler")
if h_conf is not None:
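
`level == None` becoming `level is None` is pylint C0121 (singleton-comparison). In pandas-heavy code, identity is also the safer test: `==` can be overloaded, and with a DataFrame on the left it broadcasts elementwise instead of yielding a single bool. For instance:

    import pandas as pd

    df = pd.DataFrame({"a": [1, None]})
    print(df is None)    # False -- plain identity check, always a single bool
    # df == None         # would return an elementwise DataFrame, not a bool
    level = None
    if level is None:    # the lint-clean form used throughout this commit
        print("use selector directly")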

@@ -1,13 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import pandas as pd
import numpy as np
from typing import Union, List, Tuple
from ...data.dataset import TSDataSampler
from ...data.dataset.utils import get_level_index
from ...utils import lazy_sort_index


class Reweighter:
def __init__(self, *args, **kwargs):

@@ -62,7 +62,7 @@ class SeriesDFilter(BaseDFilter):
Override _getFilterSeries to use the rule to filter the series and get a dict of {inst => series}, or override filter_main for more advanced series filter rule
"""

def __init__(self, fstart_time=None, fend_time=None):
def __init__(self, fstart_time=None, fend_time=None, keep=False):
"""Init function for filter base class.
Filter a set of instruments based on a certain rule within a certain period assigned by fstart_time and fend_time.

@@ -72,10 +72,13 @@ class SeriesDFilter(BaseDFilter):
the time for the filter rule to start filter the instruments.
fend_time: str
the time for the filter rule to stop filter the instruments.
keep: bool
whether to keep the instruments of which features don't exist in the filter time span.
"""
super(SeriesDFilter, self).__init__()
self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None
self.filter_end_time = pd.Timestamp(fend_time) if fend_time else None
self.keep = keep

def _getTimeBound(self, instruments):
"""Get time bound for all instruments.

@@ -330,12 +333,9 @@ class ExpressionDFilter(SeriesDFilter):
filter the feature ending by this time.
rule_expression: str
an input expression for the rule.
keep: bool
whether to keep the instruments of which features don't exist in the filter time span.
"""
super(ExpressionDFilter, self).__init__(fstart_time, fend_time)
super(ExpressionDFilter, self).__init__(fstart_time, fend_time, keep=keep)
self.rule_expression = rule_expression
self.keep = keep

def _getFilterSeries(self, instruments, fstart, fend):
# do not use dataset cache
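
The shape of the refactor above: the `keep` flag moves from the subclass into the base constructor, so every series filter gains the attribute in one place and subclasses simply forward it through `super().__init__` instead of re-assigning it. A stripped-down sketch with abbreviated, illustrative class names:

    class BaseFilter:
        def __init__(self, start=None, end=None, keep=False):
            self.start, self.end = start, end
            self.keep = keep  # defined once, in the base __init__

    class ExprFilter(BaseFilter):
        def __init__(self, rule, start=None, end=None, keep=False):
            super().__init__(start, end, keep=keep)  # forwarded, not duplicated
            self.rule = rule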

@@ -17,7 +17,6 @@ class InstProcessor:
df : pd.DataFrame
The raw_df of handler or result from previous processor.
"""
pass

def __str__(self):
return f"{self.__class__.__name__}:{json.dumps(self.__dict__, sort_keys=True, default=str)}"

@@ -5,8 +5,6 @@
from __future__ import division
from __future__ import print_function

import sys
import abc
import numpy as np
import pandas as pd

@@ -15,7 +13,6 @@ from scipy.stats import percentileofscore

from .base import Expression, ExpressionOps, Feature

from ..config import C
from ..log import get_module_logger
from ..utils import get_callable_kwargs

@@ -331,7 +328,7 @@ class NpPairOperator(PairOperator):
res = getattr(np, self.func)(series_left, series_right)
except ValueError as e:
get_module_logger("ops").debug(warning_info)
raise ValueError(f"{str(e)}. \n\t{warning_info}")
raise ValueError(f"{str(e)}. \n\t{warning_info}") from e
else:
if check_length and len(series_left) != len(series_right):
get_module_logger("ops").debug(warning_info)

@@ -1430,21 +1427,20 @@ class PairRolling(ExpressionOps):
return max(left_br, right_br)

def get_extended_window_size(self):
if isinstance(self.feature_left, Expression):
ll, lr = self.feature_left.get_extended_window_size()
else:
ll, lr = 0, 0
if isinstance(self.feature_right, Expression):
rl, rr = self.feature_right.get_extended_window_size()
else:
rl, rr = 0, 0
if self.N == 0:
get_module_logger(self.__class__.__name__).warning(
"The PairRolling(ATTR, 0) will not be accurately calculated"
)
return -np.inf, max(lr, rr)
else:
if isinstance(self.feature_left, Expression):
ll, lr = self.feature_left.get_extended_window_size()
else:
ll, lr = 0, 0

if isinstance(self.feature_right, Expression):
rl, rr = self.feature_right.get_extended_window_size()
else:
rl, rr = 0, 0
return max(ll, rl) + self.N - 1, max(lr, rr)

@@ -13,7 +13,7 @@ from .config import C


class MetaLogger(type):
def __new__(mcs, name, bases, attrs):
def __new__(mcs, name, bases, attrs): # pylint: disable=C0204
wrapper_dict = logging.Logger.__dict__.copy()
for key in wrapper_dict:
if key not in attrs and key != "__reduce__":

@@ -164,7 +164,7 @@ class LogFilter(logging.Filter):
if isinstance(self.param, str):
allow = not self.match_msg(self.param, record.msg)
elif isinstance(self.param, list):
allow = not any([self.match_msg(p, record.msg) for p in self.param])
allow = not any(self.match_msg(p, record.msg) for p in self.param)
return allow
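
Dropping the brackets inside `any()` is pylint R1729 (use-a-generator): with a generator expression `any()` can short-circuit on the first match, whereas the list comprehension builds the entire list before the first element is even inspected:

    msgs = ["keep", "drop-me", "keep"]
    any([m.startswith("drop") for m in msgs])  # evaluates every element first
    any(m.startswith("drop") for m in msgs)    # stops at the first True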

@@ -201,7 +201,7 @@ def set_global_logger_level(level: int, return_orig_handler_level: bool = False)

"""
_handler_level_map = {}
qlib_logger = logging.root.manager.loggerDict.get("qlib", None)
qlib_logger = logging.root.manager.loggerDict.get("qlib", None) # pylint: disable=E1101
if qlib_logger is not None:
for _handler in qlib_logger.handlers:
_handler_level_map[_handler] = _handler.level

@@ -13,7 +13,6 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
@abc.abstractmethod
def predict(self, *args, **kwargs) -> object:
"""Make predictions after modeling things"""
pass

def __call__(self, *args, **kwargs) -> object:
"""leverage Python syntactic sugar to make the models' behaviors like functions"""

@@ -13,7 +13,7 @@ reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
"""

from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
from typing import Callable, Union
from typing import Callable
from joblib import Parallel, delayed

@@ -27,6 +27,9 @@ class FeatureInt:
class LightGBMFInt(FeatureInt):
"""LightGBM (F)eature (Int)erpreter"""

def __init__(self):
self.model = None

def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance

@@ -35,6 +38,8 @@ class LightGBMFInt(FeatureInt):
parameters reference:
https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance
"""
return pd.Series(self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()).sort_values(
return pd.Series(
self.model.feature_importance(*args, **kwargs), index=self.model.feature_name()
).sort_values( # pylint: disable=E1101
ascending=False
)

@@ -4,8 +4,6 @@
import abc
from qlib.model.meta.task import MetaTask
from typing import Dict, Union, List, Tuple, Text
from ...workflow.task.gen import RollingGen, task_generator
from ...data.dataset.handler import DataHandler
from ...utils.serial import Serializable


@@ -73,4 +71,3 @@ class MetaTaskDataset(Serializable, metaclass=abc.ABCMeta):
seg : Text
the name of the segment
"""
pass

@@ -2,10 +2,8 @@
# Licensed under the MIT License.

import abc
from qlib.contrib.meta.data_selection.dataset import MetaDatasetDS
from typing import Union, List, Tuple
from typing import List

from qlib.model.meta.task import MetaTask
from .dataset import MetaTaskDataset


@@ -23,7 +21,6 @@ class MetaModel(metaclass=abc.ABCMeta):
"""
The training process of the meta-model.
"""
pass

@abc.abstractmethod
def inference(self, *args, **kwargs) -> object:

@@ -35,7 +32,6 @@ class MetaModel(metaclass=abc.ABCMeta):
object:
Some information to guide the model learning
"""
pass


class MetaTaskModel(MetaModel):

@@ -1,9 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import abc
from typing import Union, List, Tuple

from qlib.data.dataset import Dataset
from ...utils import init_instance_by_config

@@ -91,7 +91,7 @@ class RiskModel(BaseModel):
"return_decomposed_components" in inspect.getfullargspec(self._predict).args
), "This risk model does not support return decomposed components of the covariance matrix "

F, cov_b, var_u = self._predict(X, return_decomposed_components=True)
F, cov_b, var_u = self._predict(X, return_decomposed_components=True) # pylint: disable=E1123
return F, cov_b, var_u

# estimate covariance

@@ -12,17 +12,13 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
"""

import socket
import time
import re
from typing import Callable, List

from tqdm.auto import tqdm
from qlib.data.dataset import Dataset
from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_callable_kwargs, init_instance_by_config, auto_filter_kwargs, fill_placeholder
from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.data.dataset.weight import Reweighter

@@ -7,7 +7,6 @@ from typing import Union
from ..backtest.executor import BaseExecutor
from .interpreter import StateInterpreter, ActionInterpreter
from ..utils import init_instance_by_config
from .interpreter import BaseInterpreter


class BaseRLEnv:

@@ -6,12 +6,8 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from typing import List, Tuple, Union
import pandas as pd
from typing import Tuple, Union

from ..model.base import BaseModel
from ..data.dataset import DatasetH
from ..data.dataset.utils import convert_index_format
from ..rl.interpreter import ActionInterpreter, StateInterpreter
from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager

@@ -139,8 +139,8 @@ def parse_config(config):
# Check whether the str can be parsed
try:
return yaml.safe_load(config)
except BaseException:
raise ValueError("cannot parse config!")
except BaseException as base_exp:
raise ValueError("cannot parse config!") from base_exp


#################### Other ####################

@@ -436,7 +436,7 @@ def is_tradable_date(cur_date):
date : pandas.Timestamp
current date
"""
from ..data import D
from ..data import D # pylint: disable=C0415

return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date())

@@ -453,7 +453,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):

"""

from ..data import D
from ..data import D # pylint: disable=C0415

start = get_date_by_shift(trading_date, left_shift, future=future)
end = get_date_by_shift(trading_date, right_shift, future=future)

@@ -476,7 +476,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="
when align is "left"/"right", it will try to align to left/right nearest trading date before shifting when `trading_date` is not a trading date

"""
from qlib.data import D
from qlib.data import D # pylint: disable=C0415

cal = D.calendar(future=future, freq=freq)
trading_date = pd.to_datetime(trading_date)

@@ -529,7 +529,7 @@ def transform_end_date(end_date=None, freq="day"):
date : pandas.Timestamp
current date
"""
from ..data import D
from ..data import D # pylint: disable=C0415

last_date = D.calendar(freq=freq)[-1]
if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)):

@@ -810,7 +810,7 @@ def fill_placeholder(config: dict, config_extend: dict):
elif isinstance(now_item, dict):
item_keys = now_item.keys()
for key in item_keys:
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
if isinstance(now_item[key], (list, dict)):
item_queue.append(now_item[key])
tail += 1
elif isinstance(now_item[key], str):
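
Merging the two `isinstance` calls is pylint R1701 (consider-merging-isinstance): `isinstance` accepts a tuple of types and returns True when any of them matches, so the rewrite is an exact equivalent:

    item = {"k": "v"}
    if isinstance(item, list) or isinstance(item, dict):  # before
        pass
    if isinstance(item, (list, dict)):                    # after: one call, tuple of types
        pass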

@@ -10,16 +10,10 @@ class QlibException(Exception):
class RecorderInitializationError(QlibException):
"""Error type for re-initialization when starting an experiment"""

pass


class LoadObjectError(QlibException):
"""Error type for Recorder when can not load object"""

pass


class ExpAlreadyExistError(Exception):
"""Experiment already exists"""

pass

@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import contextlib
import os
import shutil
import tempfile

@@ -153,8 +153,8 @@ class Index:
"""
try:
return self.index_map[self._convert_type(item)]
except IndexError:
raise KeyError(f"{item} can't be found in {self}")
except IndexError as index_e:
raise KeyError(f"{item} can't be found in {self}") from index_e

def __or__(self, other: "Index"):
return Index(idx_list=list(set(self.idx_list) | set(other.idx_list)))

@@ -101,8 +101,10 @@ class FileManager(ObjManager):
def create_path(self) -> str:
try:
return tempfile.mkdtemp(prefix=str(C["file_manager_path"]) + os.sep)
except AttributeError:
raise NotImplementedError(f"If path is not given, the `create_path` function should be implemented")
except AttributeError as attribute_e:
raise NotImplementedError(
f"If path is not given, the `create_path` function should be implemented"
) from attribute_e

def save_obj(self, obj, name):
with (self.path / name).open("wb") as f:

@@ -70,12 +70,12 @@ def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=No
the feature with higher or equal frequency
"""

from ..data.data import D
from ..data.data import D # pylint: disable=C0415

try:
_result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache)
_freq = freq
except (ValueError, KeyError):
except (ValueError, KeyError) as value_key_e:
_, norm_freq = Freq.parse(freq)
if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]:
try:

@@ -88,7 +88,7 @@ def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=No
_result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
_freq = "1min"
else:
raise ValueError(f"freq {freq} is not supported")
raise ValueError(f"freq {freq} is not supported") from value_key_e
return _result, _freq


@@ -172,7 +172,7 @@ def resam_ts_data(

selector_datetime = slice(start_time, end_time)

from ..data.dataset.utils import get_level_index
from ..data.dataset.utils import get_level_index # pylint: disable=C0415

feature = lazy_sort_index(ts_feature)

@@ -2,7 +2,7 @@
# Licensed under the MIT License.

from contextlib import contextmanager
from typing import Text, Optional, Any, Dict, Text, Optional
from typing import Text, Optional, Any, Dict
from .expm import ExpManager
from .exp import Experiment
from .recorder import Recorder

@@ -1,7 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import sys, os
import sys
import os
from pathlib import Path

import qlib

@@ -2,10 +2,10 @@
# Licensed under the MIT License.

from typing import Dict, List, Union
import mlflow, logging
import mlflow
import logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
from pathlib import Path
from .recorder import Recorder, MLflowRecorder
from ..log import get_module_logger

@@ -271,7 +271,7 @@ class MLflowExperiment(Experiment):

return self.active_recorder

def end(self, recorder_status):
def end(self, recorder_status=Recorder.STATUS_S):
if self.active_recorder is not None:
self.active_recorder.end_run(recorder_status)
self.active_recorder = None

@@ -299,8 +299,10 @@ class MLflowExperiment(Experiment):
run = self._client.get_run(recorder_id)
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
return recorder
except MlflowException:
raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.")
except MlflowException as mlflow_exp:
raise ValueError(
"No valid recorder has been found, please make sure the input recorder id is correct."
) from mlflow_exp
elif recorder_name is not None:
logger.warning(
f"Please make sure the recorder name {recorder_name} is unique, we will only return the latest recorder if there exist several matched the given name."

@@ -332,7 +334,7 @@ class MLflowExperiment(Experiment):
except MlflowException as e:
raise Exception(
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
)
) from e

UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!

@@ -362,10 +364,10 @@ class MLflowExperiment(Experiment):
)
rids = []
recorders = []
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
for i, n in enumerate(runs):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=n)
if status is None or recorder.status == status:
rids.append(runs[i].info.run_id)
rids.append(n.info.run_id)
recorders.append(recorder)

if rtype == Experiment.RT_D:
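
`for i in range(len(runs))` becoming `enumerate` is pylint C0200 (consider-using-enumerate): the loop gets the index and the element together and avoids repeated `runs[i]` lookups, exactly as in the hunk above:

    runs = ["run_a", "run_b"]
    for i in range(len(runs)):      # before: index-based access everywhere
        print(i, runs[i])
    for i, run in enumerate(runs):  # after: index and element in one step
        print(i, run)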

@@ -6,9 +6,7 @@ import mlflow
from filelock import FileLock
from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.entities import ViewType
import os, logging
from pathlib import Path
from contextlib import contextmanager
import os
from typing import Optional, Text

from .exp import MLflowExperiment, Experiment

@@ -203,7 +201,7 @@ class ExpManager:
# So we supported it in the interface wrapper
pr = urlparse(self.uri)
if pr.scheme == "file":
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f:
with FileLock(os.path.join(pr.netloc, pr.path, "filelock")) as f: # pylint: disable=E0110
return self.create_exp(experiment_name), True
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
try:

@@ -363,7 +361,7 @@ class MLflowExpManager(ExpManager):
experiment_id = self.client.create_experiment(experiment_name)
except MlflowException as e:
if e.error_code == ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
raise ExpAlreadyExistError()
raise ExpAlreadyExistError() from e
raise e

experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)

@@ -387,10 +385,10 @@ class MLflowExpManager(ExpManager):
raise MlflowException("No valid experiment has been found.")
experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri)
return experiment
except MlflowException:
except MlflowException as e:
raise ValueError(
"No valid experiment has been found, please make sure the input experiment id is correct."
)
) from e
elif experiment_name is not None:
try:
exp = self.client.get_experiment_by_name(experiment_name)

@@ -401,9 +399,9 @@ class MLflowExpManager(ExpManager):
except MlflowException as e:
raise ValueError(
"No valid experiment has been found, please make sure the input experiment name is correct."
)
) from e

def search_records(self, experiment_ids, **kwargs):
def search_records(self, experiment_ids=None, **kwargs):
filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string")
run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type")
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
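
Giving `experiment_ids` a default value appears aimed at pylint's W0221 (arguments-differ): an override should stay call-compatible with the method it replaces, so a parameter the caller may omit gets the same default in the subclass. A minimal sketch with illustrative class names, not the real qlib hierarchy:

    class Manager:
        def search_records(self, experiment_ids=None, **kwargs):
            raise NotImplementedError

    class MLflowManager(Manager):
        # same signature as the base, so callers can omit experiment_ids either way
        def search_records(self, experiment_ids=None, **kwargs):
            return []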

@@ -425,7 +423,7 @@ class MLflowExpManager(ExpManager):
except MlflowException as e:
raise Exception(
f"Error: {e}. Something went wrong when deleting experiment. Please check if the name/id of the experiment is correct."
)
) from e

def list_experiments(self):
# retrieve all the existing experiments

@@ -83,15 +83,14 @@ For simplicity
"""

import logging
from typing import Callable, Dict, List, Union
from typing import Callable, List, Union

import pandas as pd
from qlib import get_module_logger
from qlib.data.data import D
from qlib.log import set_global_logger_level
from qlib.model.ens.ensemble import AverageEnsemble
from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR
from qlib.utils import flatten_dict
from qlib.model.trainer import Trainer, TrainerR
from qlib.utils.serial import Serializable
from qlib.workflow.online.strategy import OnlineStrategy
from qlib.workflow.task.collect import MergeCollector

@@ -5,9 +5,7 @@
OnlineStrategy module is an element of online serving.
"""

from copy import deepcopy
from typing import List, Tuple, Union
from qlib.data.data import D
from typing import List, Union
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.utils import transform_end_date

@@ -148,7 +148,7 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
self.rmdl = loader_cls(rec=record)

latest_date = D.calendar(freq=freq)[-1]
if to_date == None:
if to_date is None:
to_date = latest_date
to_date = pd.Timestamp(to_date)

@@ -191,7 +191,9 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
else:
hist_ref = self.hist_ref

start_time_buffer = get_date_by_shift(self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq)
start_time_buffer = get_date_by_shift(
self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq # pylint: disable=E1130
)
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
seg = {"test": (start_time, self.to_date)}
return self.rmdl.get_dataset(

@@ -8,10 +8,8 @@ This allows us to use efficient submodels as the market-style changing.
"""

from typing import List, Union
from qlib.data.dataset import TSDatasetH

from qlib.log import get_module_logger
from qlib.utils import get_callable_kwargs
from qlib.utils.exceptions import LoadObjectError
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder

Some files were not shown because too many files have changed in this diff.