Mirror of https://github.com/microsoft/qlib.git
pylint code refine & Fix nested example (#848)
* refine code by CI
* fix argument error
* fix nested example
This commit is contained in:
Parent: c3996955ef
Commit: d0113ea7df
@@ -155,6 +155,8 @@ class NestedDecisionExecutionWorkflow:
        },
    }

    exp_name = "nested"

    port_analysis_config = {
        "executor": {
            "class": "NestedExecutor",
@@ -230,7 +232,7 @@ class NestedDecisionExecutionWorkflow:
        qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None)

    def _train_model(self, model, dataset):
        with R.start(experiment_name="train"):
        with R.start(experiment_name=self.exp_name):
            R.log_params(**flatten_dict(self.task))
            model.fit(dataset)
            R.save_objects(**{"params.pkl": model})
@@ -257,7 +259,7 @@ class NestedDecisionExecutionWorkflow:
        self.port_analysis_config["strategy"] = strategy_config
        self.port_analysis_config["backtest"]["benchmark"] = self.benchmark

        with R.start(experiment_name="backtest"):
        with R.start(experiment_name=self.exp_name, resume=True):
            recorder = R.get_recorder()
            par = PortAnaRecord(
                recorder,
@@ -382,7 +384,7 @@ class NestedDecisionExecutionWorkflow:
        }
        pa_conf["backtest"]["benchmark"] = self.benchmark

        with R.start(experiment_name="backtest"):
        with R.start(experiment_name=self.exp_name, resume=True):
            recorder = R.get_recorder()
            par = PortAnaRecord(recorder, pa_conf)
            par.generate()
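The four hunks above are the "fix nested example" part of the commit: a new class attribute `exp_name = "nested"` replaces the hard-coded experiment names "train" and "backtest", and the backtest stages reopen that experiment with `resume=True` instead of starting a separate one. A minimal sketch of that recording pattern, assuming qlib is installed and initialised with default data (the parameters and saved object below are placeholders, not the workflow's real model):

import qlib
from qlib.workflow import R

qlib.init()  # assumes the default provider_uri already holds prepared data

exp_name = "nested"  # one experiment shared by the training and backtest stages

# Training stage: everything is recorded under `exp_name`.
with R.start(experiment_name=exp_name):
    R.log_params(learning_rate=0.01)                     # placeholder hyper-parameter
    R.save_objects(**{"params.pkl": {"weights": 1.0}})   # placeholder artifact

# Backtest stage: reopen the same experiment instead of creating a new one.
with R.start(experiment_name=exp_name, resume=True):
    recorder = R.get_recorder()
    print(recorder.experiment_id, recorder.id)
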
@@ -536,7 +536,7 @@ class Exchange:
            deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
            if deal_amount == 0:
                continue
            elif deal_amount > 0:
            if deal_amount > 0:
                # buy stock
                buy_order_list.append(
                    Order(
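The `elif` to `if` change above is the pylint no-else-continue style fix: once the `deal_amount == 0` branch executes `continue`, control never reaches the next test with a zero amount, so the `elif` adds nothing. The same refactor on a toy loop (names are illustrative):

orders = [0, 3, -2]

for amount in orders:
    if amount == 0:
        continue
    # Only reachable when amount != 0, so a plain `if` is equivalent to `elif`.
    if amount > 0:
        print("buy", amount)
    else:
        print("sell", amount)
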
@@ -687,9 +687,7 @@ class Exchange:
            orig_deal_amount = order.deal_amount
            order.deal_amount = max(min(vol_limit_min, orig_deal_amount), 0)
            if vol_limit_min < orig_deal_amount:
                self.logger.debug(
                    f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
                )
                self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")

    def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
        """return the real order amount after cash limit for buying.
@@ -194,7 +194,7 @@ class BaseExecutor:
        return return_value.get("execute_result")

    @abstractclassmethod
    def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
    def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
        """
        Please refer to the doc of collect_data
        The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
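`_collect_data` is declared with `@abstractclassmethod`, so its first parameter is the class object and pylint's bad-classmethod-argument check wants it named `cls`. A standalone sketch of that convention, written with the modern `classmethod` + `abstractmethod` stacking rather than the deprecated `abstractclassmethod` spelling (class and method names here are illustrative, not qlib's API):

from abc import ABC, abstractmethod


class Collector(ABC):
    @classmethod
    @abstractmethod
    def _collect_data(cls, trade_decision, level=0):
        """The first argument of a classmethod is the class itself, conventionally `cls`."""


class ListCollector(Collector):
    @classmethod
    def _collect_data(cls, trade_decision, level=0):
        return [trade_decision], {"level": level, "collector": cls.__name__}


print(ListCollector._collect_data("decision", level=1))
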
@@ -20,7 +20,7 @@ class BasePosition:
    Please refer to the `Position` class for the position
    """

    def __init__(self, cash=0.0, *args, **kwargs):
    def __init__(self, *args, cash=0.0, **kwargs):
        self._settle_type = self.ST_NO

    def skip_update(self) -> bool:
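Moving `cash` behind `*args` makes it keyword-only, so positional arguments forwarded by subclasses can no longer silently land in `cash`. A small sketch of the effect (the class is a stand-in, not qlib's BasePosition):

class PositionSketch:
    def __init__(self, *args, cash=0.0, **kwargs):
        # `cash` can only be set as PositionSketch(cash=100.0); every positional
        # argument flows into *args for subclasses to interpret.
        self.cash = cash
        self.args = args
        self.kwargs = kwargs


p = PositionSketch({"SH600000": 10}, cash=100.0)
print(p.cash, p.args)  # 100.0 ({'SH600000': 10},)
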
@@ -156,16 +156,16 @@ def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df):
    group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df)

    group_ret = {}
    for group_key in stock_weight_in_group:
        stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index)
        stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index)
    for group_key, val in stock_weight_in_group.items():
        stock_weight_in_group_start_date = min(val.index)
        stock_weight_in_group_end_date = max(val.index)

        temp_stock_ret_df = stock_ret_df[
            (stock_ret_df.index >= stock_weight_in_group_start_date)
            & (stock_ret_df.index <= stock_weight_in_group_end_date)
        ]

        group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1)
        group_ret[group_key] = (temp_stock_ret_df * val).sum(axis=1)
        # If no weight is assigned, then the return of group will be np.nan
        group_ret[group_key][group_weight[group_key] == 0.0] = np.nan
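Iterating with `.items()` gives each group's DataFrame directly and drops the repeated `stock_weight_in_group[group_key]` lookups, which is pylint's consider-using-dict-items advice. The same pattern on toy data, assuming only pandas:

import pandas as pd

dates = pd.to_datetime(["2020-01-01", "2020-01-02"])
stock_weight_in_group = {"groupA": pd.DataFrame({"s1": [0.5, 0.5]}, index=dates)}
stock_ret_df = pd.DataFrame({"s1": [0.01, -0.02]}, index=dates)

group_ret = {}
for group_key, val in stock_weight_in_group.items():  # `val` replaces repeated dict lookups
    start, end = min(val.index), max(val.index)
    ret = stock_ret_df[(stock_ret_df.index >= start) & (stock_ret_df.index <= end)]
    group_ret[group_key] = (ret * val).sum(axis=1)

print(group_ret["groupA"])
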
@@ -212,7 +212,8 @@ class PortfolioMetrics:
        path: str/ pathlib.Path()
        """
        path = pathlib.Path(path)
        r = pd.read_csv(open(path, "rb"), index_col=0)
        with path.open("rb") as f:
            r = pd.read_csv(f, index_col=0)
        r.index = pd.DatetimeIndex(r.index)

        index = r.index
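Reading the CSV through a context manager guarantees the file handle is closed even if `pd.read_csv` raises; the replaced line opened the file and never closed it. A minimal sketch with a throwaway file (the file name is hypothetical):

import pathlib

import pandas as pd

path = pathlib.Path("portfolio_metrics_demo.csv")  # hypothetical file
path.write_text("datetime,return\n2020-01-01,0.01\n2020-01-02,-0.02\n")

with path.open("rb") as f:  # handle is closed as soon as the block exits
    r = pd.read_csv(f, index_col=0)
r.index = pd.DatetimeIndex(r.index)
print(r)
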
@@ -205,10 +205,7 @@ class BaseInfrastructure:
            warnings.warn(f"infra {infra_name} is not found!")

    def has(self, infra_name):
        if infra_name in self.get_support_infra() and hasattr(self, infra_name):
            return True
        else:
            return False
        return infra_name in self.get_support_infra() and hasattr(self, infra_name)

    def update(self, other):
        support_infra = other.get_support_infra()
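The new `has` body is the simplifiable-if-statement fix: the condition already evaluates to the boolean that was being returned, so the four-line if/else collapses to a single return. For example:

def has(support_infra, obj, infra_name):
    # Equivalent to: if ...: return True / else: return False
    return infra_name in support_infra and hasattr(obj, infra_name)


class InfraSketch:
    common_infra = None


print(has({"common_infra"}, InfraSketch(), "common_infra"))   # True
print(has({"common_infra"}, InfraSketch(), "trade_account"))  # False
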
@@ -63,9 +63,7 @@ def _get_date_parse_fn(target):
        get_date_parse_fn('20120101')('2017-01-01') => '20170101'
        get_date_parse_fn(20120101)('2017-01-01') => 20170101
    """
    if isinstance(target, pd.Timestamp):
        _fn = lambda x: pd.Timestamp(x)  # Timestamp('2020-01-01')
    elif isinstance(target, int):
    if isinstance(target, int):
        _fn = lambda x: int(str(x).replace("-", "")[:8])  # 20200201
    elif isinstance(target, str) and len(target) == 8:
        _fn = lambda x: str(x).replace("-", "")[:8]  # '20200201'
@@ -158,7 +156,7 @@ class MTSDatasetH(DatasetH):
        try:
            df = self.handler._learn.copy()  # use copy otherwise recorder will fail
            # FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
        except:
        except Exception:
            warnings.warn("cannot access `_learn`, will load raw data")
            df = self.handler._data.copy()
        df.index = df.index.swaplevel()
@@ -371,7 +371,7 @@ def long_short_backtest(

def t_run():
    pred_FN = "./check_pred.csv"
    pred = pd.read_csv(pred_FN)
    pred: pd.DataFrame = pd.read_csv(pred_FN)
    pred["datetime"] = pd.to_datetime(pred["datetime"])
    pred = pred.set_index([pred.columns[0], pred.columns[1]])
    pred = pred.iloc[:9000]
@@ -554,7 +554,7 @@ class AdaRNN(nn.Module):
        return fc_out


class TransferLoss(object):
class TransferLoss:
    def __init__(self, loss_type="cosine", input_dim=512):
        """
        Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
@@ -98,7 +98,6 @@ class DNNModelPytorch(Model):
            "\nlr_decay_steps : {}"
            "\noptimizer : {}"
            "\nloss_type : {}"
            "\neval_steps : {}"
            "\nseed : {}"
            "\ndevice : {}"
            "\nuse_GPU : {}"
@@ -113,7 +112,6 @@ class DNNModelPytorch(Model):
                lr_decay_steps,
                optimizer,
                loss,
                eval_steps,
                seed,
                self.device,
                self.use_gpu,
@@ -331,8 +329,8 @@ class Net(nn.Module):
        dnn_layers = []
        drop_input = nn.Dropout(0.05)
        dnn_layers.append(drop_input)
        for i, (input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
            fc = nn.Linear(input_dim, hidden_units)
        for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
            fc = nn.Linear(_input_dim, hidden_units)
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
            bn = nn.BatchNorm1d(hidden_units)
            seq = nn.Sequential(fc, bn, activation)
@@ -19,7 +19,7 @@ import torch.nn.functional as F

try:
    from torch.utils.tensorboard import SummaryWriter
except:
except ImportError:
    SummaryWriter = None

from tqdm import tqdm
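Catching `ImportError` instead of a bare `except:` keeps tensorboard optional without also swallowing unrelated failures (a bare except even catches `KeyboardInterrupt`). The same optional-dependency guard in isolation; it runs whether or not torch is installed:

try:
    from torch.utils.tensorboard import SummaryWriter  # optional dependency
except ImportError:
    SummaryWriter = None  # callers must check for None before logging


def make_writer(log_dir="./runs"):
    # Returns None when tensorboard support is unavailable.
    return None if SummaryWriter is None else SummaryWriter(log_dir=log_dir)


print("tensorboard available:", SummaryWriter is not None)
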
@@ -257,7 +257,7 @@ class TRAModel(Model):
            total_loss += loss.item()
            total_count += 1

        if self.use_daily_transport and len(P_all):
        if self.use_daily_transport and len(P_all) > 0:
            P_all = pd.concat(P_all, axis=0)
            prob_all = pd.concat(prob_all, axis=0)
            choice_all = pd.concat(choice_all, axis=0)
@@ -15,7 +15,6 @@ from plotly.figure_factory import create_distplot


class BaseGraph:
    """ """

    _name = None

@@ -297,8 +296,8 @@ class SubplotsGraph:

        :return:
        """
        self._sub_graph_data = list()
        self._subplot_titles = list()
        self._sub_graph_data = []
        self._subplot_titles = []

        for i, column_name in enumerate(self._df.columns):
            row = math.ceil((i + 1) / self.__cols)
@@ -594,7 +594,7 @@ class TSDatasetH(DatasetH):
        flt_kwargs = deepcopy(kwargs)
        if flt_col is not None:
            flt_kwargs["col_set"] = flt_col
            flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
            flt_data = super()._prepare_seg(ext_slice, **flt_kwargs)
            assert len(flt_data.columns) == 1
        else:
            flt_data = None
@@ -1407,14 +1407,14 @@ class PairRolling(ExpressionOps):
        )

    def get_extended_window_size(self):
        ll, lr = self.feature_left.get_extended_window_size()
        rl, rr = self.feature_right.get_extended_window_size()
        if self.N == 0:
            get_module_logger(self.__class__.__name__).warning(
                "The PairRolling(ATTR, 0) will not be accurately calculated"
            )
            return self.feature.get_extended_window_size()
            return -np.inf, max(lr, rr)
        else:
            ll, lr = self.feature_left.get_extended_window_size()
            rl, rr = self.feature_right.get_extended_window_size()
            return max(ll, rl) + self.N - 1, max(lr, rr)

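The `PairRolling` hunk fixes a crash: the old `N == 0` branch returned `self.feature.get_extended_window_size()`, but a pair operator has `feature_left`/`feature_right`, not `feature`. The new logic asks both children for their (past, future) bar requirements; an expanding window (`N == 0`) needs unbounded history, while a window of `N` needs `max(ll, rl) + N - 1` past bars and `max(lr, rr)` future bars. A worked sketch of that arithmetic in plain Python (not the qlib operator classes):

import numpy as np


def pair_rolling_extended_window(left, right, n):
    """left/right are the (past, future) bar requirements of the two child expressions."""
    ll, lr = left
    rl, rr = right
    if n == 0:
        return -np.inf, max(lr, rr)  # expanding window: unbounded history needed
    return max(ll, rl) + n - 1, max(lr, rr)  # N-1 extra past bars for the rolling window


# A pair operator over window 5 whose children need 10 and 3 past bars:
print(pair_rolling_extended_window((10, 0), (3, 0), 5))  # (14, 0)
print(pair_rolling_extended_window((10, 0), (3, 0), 0))  # (-inf, 0)
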
@@ -120,7 +120,7 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
        # If cache is enabled, then return cache directly
        if self.enable_read_cache:
            key = "orig_file" + str(self.uri)
            if not key in H["c"]:
            if key not in H["c"]:
                H["c"][key] = self._read_calendar()
            _calendar = H["c"][key]
        else:
@@ -50,7 +50,7 @@ class StructuredCovEstimator(RiskModel):
            num_factors (int): number of components to keep.
            kwargs: see `RiskModel` for more information
        """
        if "nan_option" in kwargs.keys():
        if "nan_option" in kwargs:
            assert kwargs["nan_option"] in [self.DEFAULT_NAN_OPTION], "nan_option={} is not supported".format(
                kwargs["nan_option"]
            )
@@ -254,21 +254,21 @@ class TrainerR(Trainer):
            recs.append(rec)
        return recs

    def end_train(self, recs: list, **kwargs) -> List[Recorder]:
    def end_train(self, models: list, **kwargs) -> List[Recorder]:
        """
        Set STATUS_END tag to the recorders.

        Args:
            recs (list): a list of trained recorders.
            models (list): a list of trained recorders.

        Returns:
            List[Recorder]: the same list as the param.
        """
        if isinstance(recs, Recorder):
            recs = [recs]
        for rec in recs:
        if isinstance(models, Recorder):
            models = [models]
        for rec in models:
            rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
        return recs
        return models


class DelayTrainerR(TrainerR):
@@ -289,13 +289,13 @@ class DelayTrainerR(TrainerR):
        self.end_train_func = end_train_func
        self.delay = True

    def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
    def end_train(self, models, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
        """
        Given a list of Recorder and return a list of trained Recorder.
        This class will finish real data loading and model fitting.

        Args:
            recs (list): a list of Recorder, the tasks have been saved to them
            models (list): a list of Recorder, the tasks have been saved to them
            end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
            experiment_name (str): the experiment name, None for use default name.
            kwargs: the params for end_train_func.
@@ -303,18 +303,18 @@ class DelayTrainerR(TrainerR):
        Returns:
            List[Recorder]: a list of Recorders
        """
        if isinstance(recs, Recorder):
            recs = [recs]
        if isinstance(models, Recorder):
            models = [models]
        if end_train_func is None:
            end_train_func = self.end_train_func
        if experiment_name is None:
            experiment_name = self.experiment_name
        for rec in recs:
        for rec in models:
            if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
                continue
            end_train_func(rec, experiment_name, **kwargs)
            rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
        return recs
        return models


class TrainerRM(Trainer):
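Renaming `recs` to `models` lines the overrides up with the base `Trainer.end_train` signature, which is the "fix argument error" item in the commit message: a caller that passes the argument by keyword would otherwise fail on the subclasses. A small sketch of that failure mode (the classes below are illustrative stand-ins, not the qlib Trainer API):

class TrainerBase:
    def end_train(self, models, **kwargs):
        return models


class TrainerBefore(TrainerBase):
    def end_train(self, recs, **kwargs):  # parameter name drifted from the base class
        return recs


class TrainerAfter(TrainerBase):
    def end_train(self, models, **kwargs):  # keyword calls keep working
        return models


print(TrainerAfter().end_train(models=["rec1"]))
try:
    TrainerBefore().end_train(models=["rec1"])
except TypeError as err:
    print("keyword call breaks on the old signature:", err)
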
@@ -2,6 +2,7 @@
# Licensed under the MIT License.

import re
import sys
import qlib
import shutil
import zipfile
@@ -101,7 +102,7 @@ class GetData:
                f"\nAre you sure you want to delete, yes(Y/y), no (N/n):"
            )
            if str(flag) not in ["Y", "y"]:
                exit()
                sys.exit()
            for _p in rm_dirs:
                logger.warning(f"delete: {_p}")
                shutil.rmtree(_p)
@@ -654,16 +654,13 @@ def exists_qlib_data(qlib_dir):
def check_qlib_data(qlib_config):
    inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments")
    for _p in inst_dir.glob("*.txt"):
        try:
            assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
                f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
                f"\n\tIf you are using the data provided by qlib: "
                f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
                f"\n\tIf you are using your own data, please dump the data again: "
                f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
            )
        except AssertionError:
            raise
        assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
            f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
            f"\n\tIf you are using the data provided by qlib: "
            f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
            f"\n\tIf you are using your own data, please dump the data again: "
            f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
        )


def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
@@ -4,8 +4,7 @@

# Base exception class
class QlibException(Exception):
    def __init__(self, message):
        super(QlibException, self).__init__(message)
    pass


class RecorderInitializationError(QlibException):
@@ -80,8 +80,7 @@ class AsyncCaller:
            data = self._q.get()
            if data == self.STOP_MARK:
                break
            else:
                data()
            data()

    def __call__(self, func, *args, **kwargs):
        self._q.put(partial(func, *args, **kwargs))
@@ -187,7 +187,7 @@ def resam_ts_data(
    if isinstance(feature.index, pd.MultiIndex):
        if callable(method):
            method_func = method
            return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
            return feature.groupby(level="instrument").apply(method_func, **method_kwargs)
        elif isinstance(method, str):
            return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
        else:
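`GroupBy.apply` already forwards extra keyword arguments to the applied function, so wrapping `method_func` in a lambda was an unnecessary-lambda; passing the function and its kwargs directly is equivalent. A toy pandas example of the two forms:

import pandas as pd

feature = pd.Series(
    [1.0, 2.0, 3.0, 4.0],
    index=pd.MultiIndex.from_product(
        [["SH600000", "SZ000001"], [0, 1]], names=["instrument", "datetime"]
    ),
)


def last_scaled(x, scale=1.0):
    return x.iloc[-1] * scale


# Equivalent to .apply(lambda x: last_scaled(x, scale=2.0)), without the lambda.
print(feature.groupby(level="instrument").apply(last_scaled, scale=2.0))
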
@@ -416,6 +416,11 @@ class QlibRecorder:
            # Case 5
            recorder = R.get_recorder(recorder_id='2e7a4efd66574fa49039e00ffaefa99d', experiment_name='test')


        Here are some things users may concern
        - Q: What recorder will it return if multiple recorder meets the query (e.g. query with experiment_name)
        - A: If mlflow backend is used, then the recorder with the latest `start_time` will be returned. Because MLflow's `search_runs` function guarantee it

        Parameters
        ----------
        recorder_id : str
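The added docstring lines pin down which run comes back when several recorders match a query: with the MLflow backend, `search_runs` sorts by `start_time` descending, so the most recently started recorder is returned. A hedged usage sketch (qlib must already be initialised, and the experiment name "test" is a placeholder):

from qlib.workflow import R

# If experiment "test" holds several runs, the recorder returned here is the one
# with the latest start_time, following MLflow's default search_runs ordering.
recorder = R.get_recorder(experiment_name="test")
print(recorder.id)
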
@@ -287,6 +287,9 @@ class MLflowExperiment(Experiment):
        """
        Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
        raise errors.

        Quoting docs of search_runs from MLflow
        > The default ordering is to sort by start_time DESC, then run_id.
        """
        assert (
            recorder_id is not None or recorder_name is not None
@@ -355,7 +355,7 @@ class MLflowRecorder(Recorder):
                shutil.rmtree(Path(path).absolute().parent)
                return data
            except Exception as e:
                raise LoadObjectError(message=str(e))
                raise LoadObjectError(str(e))

    @AsyncCaller.async_dec(ac_attr="async_log")
    def log_params(self, **kwargs):