pylint code refine & Fix nested example (#848)

* refine code by CI

* fix argument error

* fix nested example
you-n-g 2022-01-14 09:09:21 +08:00 committed by GitHub
Parent c3996955ef
Commit d0113ea7df
No known key found for this signature
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 65 additions and 68 deletions

View File

@@ -155,6 +155,8 @@ class NestedDecisionExecutionWorkflow:
},
}
+ exp_name = "nested"
port_analysis_config = {
"executor": {
"class": "NestedExecutor",
@@ -230,7 +232,7 @@ class NestedDecisionExecutionWorkflow:
qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None)
def _train_model(self, model, dataset):
- with R.start(experiment_name="train"):
+ with R.start(experiment_name=self.exp_name):
R.log_params(**flatten_dict(self.task))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
@@ -257,7 +259,7 @@ class NestedDecisionExecutionWorkflow:
self.port_analysis_config["strategy"] = strategy_config
self.port_analysis_config["backtest"]["benchmark"] = self.benchmark
- with R.start(experiment_name="backtest"):
+ with R.start(experiment_name=self.exp_name, resume=True):
recorder = R.get_recorder()
par = PortAnaRecord(
recorder,
@@ -382,7 +384,7 @@ class NestedDecisionExecutionWorkflow:
}
pa_conf["backtest"]["benchmark"] = self.benchmark
- with R.start(experiment_name="backtest"):
+ with R.start(experiment_name=self.exp_name, resume=True):
recorder = R.get_recorder()
par = PortAnaRecord(recorder, pa_conf)
par.generate()
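
The recurring change in this file: the hard-coded experiment names "train" and "backtest" are replaced by the shared self.exp_name, and resume=True lets the backtest phase reopen the experiment created during training instead of starting a second one. A minimal sketch of the pattern with qlib's R interface (assumes qlib.init() has already been called; the logged params are placeholders):

from qlib.workflow import R

EXP_NAME = "nested"  # one experiment shared by both phases

# Phase 1: training creates the experiment and its recorder.
with R.start(experiment_name=EXP_NAME):
    R.log_params(lr=0.01)  # placeholder parameters

# Phase 2: backtesting resumes the same experiment, so its artifacts
# land next to the training run instead of in a separate experiment.
with R.start(experiment_name=EXP_NAME, resume=True):
    recorder = R.get_recorder()
    print(recorder)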

View File

@@ -536,7 +536,7 @@ class Exchange:
deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
if deal_amount == 0:
continue
- elif deal_amount > 0:
+ if deal_amount > 0:
# buy stock
buy_order_list.append(
Order(
@@ -687,9 +687,7 @@ class Exchange:
orig_deal_amount = order.deal_amount
order.deal_amount = max(min(vol_limit_min, orig_deal_amount), 0)
if vol_limit_min < orig_deal_amount:
- self.logger.debug(
- f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
- )
+ self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
"""return the real order amount after cash limit for buying.

View File

@@ -194,7 +194,7 @@ class BaseExecutor:
return return_value.get("execute_result")
@abstractclassmethod
- def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
+ def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
"""
Please refer to the doc of collect_data
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into

View File

@@ -20,7 +20,7 @@ class BasePosition:
Please refer to the `Position` class for the position
"""
- def __init__(self, cash=0.0, *args, **kwargs):
+ def __init__(self, *args, cash=0.0, **kwargs):
self._settle_type = self.ST_NO
def skip_update(self) -> bool:
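
Moving cash behind *args makes it keyword-only in practice: a stray positional argument can no longer be captured silently as the cash balance. A self-contained illustration (class names here are invented, not qlib's):

class OldPosition:
    def __init__(self, cash=0.0, *args, **kwargs):
        self.cash = cash

class NewPosition:
    def __init__(self, *args, cash=0.0, **kwargs):
        self.cash = cash

print(OldPosition(100).cash)       # 100 -- positional arg lands in `cash`
print(NewPosition(100).cash)       # 0.0 -- 100 goes to *args instead
print(NewPosition(cash=100).cash)  # 100 -- cash must now be named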

View File

@@ -156,16 +156,16 @@ def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df):
group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df)
group_ret = {}
- for group_key in stock_weight_in_group:
- stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index)
- stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index)
+ for group_key, val in stock_weight_in_group.items():
+ stock_weight_in_group_start_date = min(val.index)
+ stock_weight_in_group_end_date = max(val.index)
temp_stock_ret_df = stock_ret_df[
(stock_ret_df.index >= stock_weight_in_group_start_date)
& (stock_ret_df.index <= stock_weight_in_group_end_date)
]
- group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1)
+ group_ret[group_key] = (temp_stock_ret_df * val).sum(axis=1)
# If no weight is assigned, then the return of group will be np.nan
group_ret[group_key][group_weight[group_key] == 0.0] = np.nan
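
Iterating with .items() binds each group's weight frame once instead of re-indexing stock_weight_in_group[group_key] on every use. The same idiom in isolation (data invented):

weights = {"groupA": [0.6, 0.4], "groupB": [1.0]}

# before: repeated dict lookups
for key in weights:
    total = sum(weights[key])

# after: key and value unpacked together
for key, val in weights.items():
    print(key, sum(val))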

View File

@@ -212,7 +212,8 @@ class PortfolioMetrics:
path: str/ pathlib.Path()
"""
path = pathlib.Path(path)
- r = pd.read_csv(open(path, "rb"), index_col=0)
+ with path.open("rb") as f:
+ r = pd.read_csv(f, index_col=0)
r.index = pd.DatetimeIndex(r.index)
index = r.index
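
Replacing pd.read_csv(open(path, "rb"), ...) with a with block guarantees the file handle is closed even if parsing raises, instead of relying on the garbage collector. A minimal runnable sketch (the CSV content is made up):

import pathlib
import pandas as pd

path = pathlib.Path("metrics.csv")  # hypothetical file
path.write_text("date,value\n2020-01-01,1.0\n")

with path.open("rb") as f:  # handle closed on success or error
    r = pd.read_csv(f, index_col=0)
r.index = pd.DatetimeIndex(r.index)
print(r)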

View File

@@ -205,10 +205,7 @@ class BaseInfrastructure:
warnings.warn(f"infra {infra_name} is not found!")
def has(self, infra_name):
- if infra_name in self.get_support_infra() and hasattr(self, infra_name):
- return True
- else:
- return False
+ return infra_name in self.get_support_infra() and hasattr(self, infra_name)
def update(self, other):
support_infra = other.get_support_infra()

View File

@@ -63,9 +63,7 @@ def _get_date_parse_fn(target):
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
get_date_parse_fn(20120101)('2017-01-01') => 20170101
"""
- if isinstance(target, pd.Timestamp):
- _fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
- elif isinstance(target, int):
+ if isinstance(target, int):
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
elif isinstance(target, str) and len(target) == 8:
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
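
Per the docstring, the helper returns a function that coerces any date-like input into the format of the target. A short demo of that documented behavior, assuming the function's tail (not shown in the hunk) falls back to returning the input unchanged:

def get_date_parse_fn(target):
    # sketch of the helper per its docstring; the real one lives in qlib
    if isinstance(target, int):
        return lambda x: int(str(x).replace("-", "")[:8])  # e.g. 20170101
    if isinstance(target, str) and len(target) == 8:
        return lambda x: str(x).replace("-", "")[:8]       # e.g. '20170101'
    return lambda x: x                                     # assumed fallback

print(get_date_parse_fn("20120101")("2017-01-01"))  # '20170101'
print(get_date_parse_fn(20120101)("2017-01-01"))    # 20170101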
@@ -158,7 +156,7 @@ class MTSDatasetH(DatasetH):
try:
df = self.handler._learn.copy() # use copy otherwise recorder will fail
# FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
- except:
+ except Exception:
warnings.warn("cannot access `_learn`, will load raw data")
df = self.handler._data.copy()
df.index = df.index.swaplevel()

View File

@@ -371,7 +371,7 @@ def long_short_backtest(
def t_run():
pred_FN = "./check_pred.csv"
- pred = pd.read_csv(pred_FN)
+ pred: pd.DataFrame = pd.read_csv(pred_FN)
pred["datetime"] = pd.to_datetime(pred["datetime"])
pred = pred.set_index([pred.columns[0], pred.columns[1]])
pred = pred.iloc[:9000]

View File

@@ -554,7 +554,7 @@ class AdaRNN(nn.Module):
return fc_out
- class TransferLoss(object):
+ class TransferLoss:
def __init__(self, loss_type="cosine", input_dim=512):
"""
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv

View File

@@ -98,7 +98,6 @@ class DNNModelPytorch(Model):
"\nlr_decay_steps : {}"
"\noptimizer : {}"
"\nloss_type : {}"
- "\neval_steps : {}"
"\nseed : {}"
"\ndevice : {}"
"\nuse_GPU : {}"
@@ -113,7 +112,6 @@ class DNNModelPytorch(Model):
lr_decay_steps,
optimizer,
loss,
- eval_steps,
seed,
self.device,
self.use_gpu,
@@ -331,8 +329,8 @@ class Net(nn.Module):
dnn_layers = []
drop_input = nn.Dropout(0.05)
dnn_layers.append(drop_input)
- for i, (input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
- fc = nn.Linear(input_dim, hidden_units)
+ for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
+ fc = nn.Linear(_input_dim, hidden_units)
activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
bn = nn.BatchNorm1d(hidden_units)
seq = nn.Sequential(fc, bn, activation)
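
zip(layers[:-1], layers[1:]) pairs consecutive layer sizes, so each nn.Linear receives (in_features, out_features) for one step of the stack; the rename to _input_dim presumably just silences a pylint warning about shadowing an outer input_dim. The construction idiom in isolation (sizes invented):

import torch.nn as nn

layers = (360, 256, 128)  # illustrative: input dim, then two hidden sizes

blocks = []
# consecutive pairs: (360, 256), (256, 128)
for _in_dim, out_dim in zip(layers[:-1], layers[1:]):
    blocks.append(
        nn.Sequential(
            nn.Linear(_in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.LeakyReLU(negative_slope=0.1),
        )
    )
net = nn.Sequential(*blocks)
print(net)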

View File

@@ -19,7 +19,7 @@ import torch.nn.functional as F
try:
from torch.utils.tensorboard import SummaryWriter
- except:
+ except ImportError:
SummaryWriter = None
from tqdm import tqdm
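
Catching ImportError instead of a bare except keeps tensorboard an optional dependency without also swallowing unrelated errors such as KeyboardInterrupt; callers then test the name for None. The pattern in isolation:

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:  # only the missing-dependency case is swallowed
    SummaryWriter = None

def make_writer(logdir):
    # degrade gracefully when tensorboard is not installed
    if SummaryWriter is None:
        print("tensorboard not available, skipping event logging")
        return None
    return SummaryWriter(log_dir=logdir)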
@@ -257,7 +257,7 @@ class TRAModel(Model):
total_loss += loss.item()
total_count += 1
- if self.use_daily_transport and len(P_all):
+ if self.use_daily_transport and len(P_all) > 0:
P_all = pd.concat(P_all, axis=0)
prob_all = pd.concat(prob_all, axis=0)
choice_all = pd.concat(choice_all, axis=0)

View File

@@ -15,7 +15,6 @@ from plotly.figure_factory import create_distplot
class BaseGraph:
""" """
_name = None
@@ -297,8 +296,8 @@ class SubplotsGraph:
:return:
"""
- self._sub_graph_data = list()
- self._subplot_titles = list()
+ self._sub_graph_data = []
+ self._subplot_titles = []
for i, column_name in enumerate(self._df.columns):
row = math.ceil((i + 1) / self.__cols)

View File

@@ -594,7 +594,7 @@ class TSDatasetH(DatasetH):
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs["col_set"] = flt_col
- flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
+ flt_data = super()._prepare_seg(ext_slice, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None
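
The super() call matters here because TSDatasetH overrides _prepare_seg: self._prepare_seg would dispatch back to the override (recursing from inside itself), while super()._prepare_seg reaches the parent DatasetH implementation. The shape of the fix in miniature (toy classes):

class Base:
    def _prepare_seg(self, seg):
        return f"base data for {seg}"

class Derived(Base):
    def _prepare_seg(self, seg):
        # self._prepare_seg(seg) here would recurse forever;
        # super() reaches the parent implementation instead.
        raw = super()._prepare_seg(seg)
        return f"wrapped({raw})"

print(Derived()._prepare_seg("2020-01"))  # wrapped(base data for 2020-01)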

View File

@@ -1407,14 +1407,14 @@ class PairRolling(ExpressionOps):
)
def get_extended_window_size(self):
- ll, lr = self.feature_left.get_extended_window_size()
- rl, rr = self.feature_right.get_extended_window_size()
if self.N == 0:
get_module_logger(self.__class__.__name__).warning(
"The PairRolling(ATTR, 0) will not be accurately calculated"
)
- return -np.inf, max(lr, rr)
+ return self.feature.get_extended_window_size()
else:
+ ll, lr = self.feature_left.get_extended_window_size()
+ rl, rr = self.feature_right.get_extended_window_size()
return max(ll, rl) + self.N - 1, max(lr, rr)
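
The else branch expresses how much extra history a pair-rolling operator needs: each operand reports its own (left, right) extension, and a window of N adds N - 1 look-back periods on top of the larger left extension. A worked example with invented numbers:

# Suppose the left expression needs (5, 0) extra periods and the
# right needs (3, 2); a PairRolling window of N = 10 then needs:
ll, lr = 5, 0
rl, rr = 3, 2
N = 10
extended = (max(ll, rl) + N - 1, max(lr, rr))
print(extended)  # (14, 2): 5 from the left operand plus 9 from the window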

View File

@@ -120,7 +120,7 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
# If cache is enabled, then return cache directly
if self.enable_read_cache:
key = "orig_file" + str(self.uri)
- if not key in H["c"]:
+ if key not in H["c"]:
H["c"][key] = self._read_calendar()
_calendar = H["c"][key]
else:
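
Besides the PEP 8 spelling (key not in ... rather than not key in ...), this block is a plain read-through cache: compute and store on a miss, serve from memory afterwards. The same pattern standalone (the calendar loader is a stand-in):

H = {"c": {}}  # cache bucket, keyed like qlib's H["c"]

def read_calendar(uri):
    print(f"expensive read of {uri}")  # happens once per uri
    return ["2020-01-01", "2020-01-02"]

def get_calendar(uri):
    key = "orig_file" + str(uri)
    if key not in H["c"]:
        H["c"][key] = read_calendar(uri)
    return H["c"][key]

get_calendar("day.txt")  # miss: triggers the read
get_calendar("day.txt")  # hit: served from cache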

View File

@@ -50,7 +50,7 @@ class StructuredCovEstimator(RiskModel):
num_factors (int): number of components to keep.
kwargs: see `RiskModel` for more information
"""
- if "nan_option" in kwargs.keys():
+ if "nan_option" in kwargs:
assert kwargs["nan_option"] in [self.DEFAULT_NAN_OPTION], "nan_option={} is not supported".format(
kwargs["nan_option"]
)

View File

@@ -254,21 +254,21 @@ class TrainerR(Trainer):
recs.append(rec)
return recs
- def end_train(self, recs: list, **kwargs) -> List[Recorder]:
+ def end_train(self, models: list, **kwargs) -> List[Recorder]:
"""
Set STATUS_END tag to the recorders.
Args:
- recs (list): a list of trained recorders.
+ models (list): a list of trained recorders.
Returns:
List[Recorder]: the same list as the param.
"""
- if isinstance(recs, Recorder):
- recs = [recs]
- for rec in recs:
+ if isinstance(models, Recorder):
+ models = [models]
+ for rec in models:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
- return recs
+ return models
class DelayTrainerR(TrainerR):
@@ -289,13 +289,13 @@ class DelayTrainerR(TrainerR):
self.end_train_func = end_train_func
self.delay = True
- def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
+ def end_train(self, models, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
Args:
- recs (list): a list of Recorder, the tasks have been saved to them
+ models (list): a list of Recorder, the tasks have been saved to them
end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for end_train_func.
@@ -303,18 +303,18 @@ class DelayTrainerR(TrainerR):
Returns:
List[Recorder]: a list of Recorders
"""
- if isinstance(recs, Recorder):
- recs = [recs]
+ if isinstance(models, Recorder):
+ models = [models]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
- for rec in recs:
+ for rec in models:
if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
continue
end_train_func(rec, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
- return recs
+ return models
class TrainerRM(Trainer):
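
The recs-to-models rename in both classes lines the overrides up with the parameter name apparently used by the base Trainer.end_train, so keyword calls such as trainer.end_train(models=...) work uniformly across subclasses (pylint flags such mismatches as arguments-renamed). The compatibility point in miniature (toy classes, not qlib's):

class Trainer:
    def end_train(self, models, **kwargs):
        raise NotImplementedError

class TrainerR(Trainer):
    def end_train(self, models, **kwargs):
        if not isinstance(models, list):  # accept one item or a list
            models = [models]
        return models

t = TrainerR()
print(t.end_train(models="rec1"))  # keyword call now matches the base class
print(t.end_train(["rec1", "rec2"]))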

View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import re
+ import sys
import qlib
import shutil
import zipfile
@@ -101,7 +102,7 @@ class GetData:
f"\nAre you sure you want to delete, yes(Y/y), no (N/n):"
)
if str(flag) not in ["Y", "y"]:
- exit()
+ sys.exit()
for _p in rm_dirs:
logger.warning(f"delete: {_p}")
shutil.rmtree(_p)
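
sys.exit() is preferred here because the bare exit() is a convenience name injected by the site module for interactive sessions and is not guaranteed to exist (e.g. under python -S or in frozen binaries), whereas sys.exit() deterministically raises SystemExit. The confirm-and-quit pattern in isolation:

import sys

def confirm_or_quit(prompt="Are you sure you want to delete, yes(Y/y), no (N/n):"):
    flag = input(prompt)
    if str(flag) not in ["Y", "y"]:
        sys.exit()  # raises SystemExit; safe outside the interactive shell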

View File

@@ -654,16 +654,13 @@ def exists_qlib_data(qlib_dir):
def check_qlib_data(qlib_config):
inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments")
for _p in inst_dir.glob("*.txt"):
- try:
- assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
- f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
- f"\n\tIf you are using the data provided by qlib: "
- f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
- f"\n\tIf you are using your own data, please dump the data again: "
- f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
- )
- except AssertionError:
- raise
+ assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
+ f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
+ f"\n\tIf you are using the data provided by qlib: "
+ f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
+ f"\n\tIf you are using your own data, please dump the data again: "
+ f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
+ )
def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:

View File

@@ -4,8 +4,7 @@
# Base exception class
class QlibException(Exception):
- def __init__(self, message):
- super(QlibException, self).__init__(message)
+ pass
class RecorderInitializationError(QlibException):
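
Dropping the __init__ works because Exception already stores and renders whatever arguments it receives; forwarding message to super().__init__ unchanged is a useless delegation (pylint W0235). It also explains the last hunk of this commit, where LoadObjectError(message=...) becomes a positional call. A two-line check:

class QlibException(Exception):
    pass

e = QlibException("calendar not found")
print(str(e), e.args)  # calendar not found ('calendar not found',)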

View File

@@ -80,8 +80,7 @@ class AsyncCaller:
data = self._q.get()
if data == self.STOP_MARK:
break
- else:
- data()
+ data()
def __call__(self, func, *args, **kwargs):
self._q.put(partial(func, *args, **kwargs))

View File

@@ -187,7 +187,7 @@ def resam_ts_data(
if isinstance(feature.index, pd.MultiIndex):
if callable(method):
method_func = method
- return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
+ return feature.groupby(level="instrument").apply(method_func, **method_kwargs)
elif isinstance(method, str):
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
else:
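
pandas' GroupBy.apply forwards extra keyword arguments to the applied function, so the lambda wrapper was redundant; passing method_func directly is equivalent and drops one call layer per group. A quick equivalence check with invented data:

import pandas as pd

df = pd.DataFrame(
    {"value": [1.0, 2.0, 3.0, 4.0]},
    index=pd.MultiIndex.from_product(
        [["2020-01-01", "2020-01-02"], ["SH600000", "SZ000001"]],
        names=["datetime", "instrument"],
    ),
)

def method_func(x, scale=1.0):
    return x["value"].sum() * scale

g = df.groupby(level="instrument")
a = g.apply(lambda x: method_func(x, scale=2.0))  # old style
b = g.apply(method_func, scale=2.0)               # new style
print(a.equals(b))  # True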

View File

@@ -416,6 +416,11 @@ class QlibRecorder:
# Case 5
recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')
+ Here are some things users may concern
+ - Q: What recorder will it return if multiple recorder meets the query (e.g. query with experiment_name)
+ - A: If mlflow backend is used, then the recorder with the latest `start_time` will be returned. Because MLflow's `search_runs` function guarantee it
Parameters
----------
recorder_id : str
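
The added docstring lines resolve the ambiguity when several recorders match one query: with the MLflow backend the most recently started run wins, since mlflow's search_runs sorts by start_time descending by default (quoted in a later hunk of this commit). A hedged usage sketch (assumes qlib is initialized and the experiment exists):

from qlib.workflow import R

# Query by experiment name only: several runs may match.
# With the MLflow backend, the recorder with the latest
# start_time is returned.
recorder = R.get_recorder(experiment_name="nested")
print(recorder)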

View File

@@ -287,6 +287,9 @@ class MLflowExperiment(Experiment):
"""
Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
raise errors.
+ Quoting docs of search_runs from MLflow
+ > The default ordering is to sort by start_time DESC, then run_id.
"""
assert (
recorder_id is not None or recorder_name is not None

View File

@@ -355,7 +355,7 @@ class MLflowRecorder(Recorder):
shutil.rmtree(Path(path).absolute().parent)
return data
except Exception as e:
- raise LoadObjectError(message=str(e))
+ raise LoadObjectError(str(e))
@AsyncCaller.async_dec(ac_attr="async_log")
def log_params(self, **kwargs):