зеркало из https://github.com/microsoft/qlib.git
135 строки
5.3 KiB
Python
135 строки
5.3 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import qlib
|
|
import fire
|
|
import pickle
|
|
|
|
from datetime import datetime
|
|
from qlib.constant import REG_CN
|
|
from qlib.data.dataset.handler import DataHandlerLP
|
|
from qlib.utils import init_instance_by_config
|
|
from qlib.tests.data import GetData
|
|
|
|
|
|
class RollingDataWorkflow:
    """Example workflow that rolls a train/valid/test window forward in time.

    An ``Alpha158`` handler without processors is built once, pickled to disk
    and reloaded.  The reloaded handler is handed to a ``RollingDataHandler``
    wrapped in a ``DatasetH``; on every rolling step after the first, the
    handler range, the processor fit range and the dataset segments are all
    moved forward by ``rolling_offset`` calendar years.
    """

    # Instrument universe passed to the pre-handler as ``instruments``.
    MARKET = "csi300"
    # Date range loaded by the pre-handler.
    start_time = "2010-01-01"
    end_time = "2019-12-31"
    # Number of iterations performed by ``rolling_process``.
    rolling_cnt = 5

    @staticmethod
    def _shift_years(ymd, years):
        """Return the ``(year, month, day)`` tuple ``ymd`` as a datetime moved ``years`` years forward."""
        return datetime(ymd[0] + years, *ymd[1:])

    def _init_qlib(self):
        """Fetch the CN qlib data (``exists_skip=True``) and initialize qlib on it."""
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def _dump_pre_handler(self, path):
        """Build the ``Alpha158`` pre-handler and pickle it to ``path``.

        The handler covers ``start_time``..``end_time`` for ``MARKET`` and is
        configured with empty infer/learn processor lists.
        """
        handler_config = {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": {
                "start_time": self.start_time,
                "end_time": self.end_time,
                "instruments": self.MARKET,
                "infer_processors": [],
                "learn_processors": [],
            },
        }
        pre_handler = init_instance_by_config(handler_config)
        pre_handler.config(dump_all=True)
        pre_handler.to_pickle(path)

    def _load_pre_handler(self, path):
        """Unpickle and return the object stored at ``path``."""
        with open(path, "rb") as file_dataset:
            # NOTE: pickle.load can execute arbitrary code embedded in the file;
            # only point this at files written by ``_dump_pre_handler`` or
            # another trusted producer.
            return pickle.load(file_dataset)

    def rolling_process(self):
        """Run the rolling data-preparation loop and print each step's splits.

        The first iteration uses the dataset exactly as constructed.  Every
        later iteration reconfigures the dataset with all time ranges shifted
        by ``rolling_offset`` years and calls ``setup_data`` with
        ``DataHandlerLP.IT_FIT_SEQ`` before preparing the splits.
        """
        self._init_qlib()

        # Dump and reload go through the same file, so name the path once.
        pre_handler_path = "pre_handler.pkl"
        self._dump_pre_handler(pre_handler_path)
        pre_handler = self._load_pre_handler(pre_handler_path)

        # (year, month, day) boundaries of the first (un-shifted) window.
        train_start_time = (2010, 1, 1)
        train_end_time = (2012, 12, 31)
        valid_start_time = (2013, 1, 1)
        valid_end_time = (2013, 12, 31)
        test_start_time = (2014, 1, 1)
        test_end_time = (2014, 12, 31)

        shift = self._shift_years

        def segments_at(years):
            """Return the train/valid/test segments moved ``years`` years forward."""
            return {
                "train": (shift(train_start_time, years), shift(train_end_time, years)),
                "valid": (shift(valid_start_time, years), shift(valid_end_time, years)),
                "test": (shift(test_start_time, years), shift(test_end_time, years)),
            }

        dataset_config = {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "RollingDataHandler",
                    # NOTE(review): bare module name -- presumably a sibling
                    # ``rolling_handler`` module must be importable; confirm.
                    "module_path": "rolling_handler",
                    "kwargs": {
                        "start_time": shift(train_start_time, 0),
                        "end_time": shift(test_end_time, 0),
                        "fit_start_time": shift(train_start_time, 0),
                        "fit_end_time": shift(train_end_time, 0),
                        "infer_processors": [
                            {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
                        ],
                        "learn_processors": [
                            {"class": "DropnaLabel"},
                            {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
                        ],
                        "data_loader_kwargs": {
                            "handler_config": pre_handler,
                        },
                    },
                },
                "segments": segments_at(0),
            },
        }

        dataset = init_instance_by_config(dataset_config)

        for rolling_offset in range(self.rolling_cnt):
            print(f"===========rolling{rolling_offset} start===========")
            if rolling_offset:
                # Offset 0 is already covered by the initial configuration;
                # only later steps need to move the window and re-fit.
                dataset.config(
                    handler_kwargs={
                        "start_time": shift(train_start_time, rolling_offset),
                        "end_time": shift(test_end_time, rolling_offset),
                        "processor_kwargs": {
                            "fit_start_time": shift(train_start_time, rolling_offset),
                            "fit_end_time": shift(train_end_time, rolling_offset),
                        },
                    },
                    segments=segments_at(rolling_offset),
                )
                dataset.setup_data(
                    handler_kwargs={
                        "init_type": DataHandlerLP.IT_FIT_SEQ,
                    }
                )

            dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
            print(dtrain, dvalid, dtest)
            # print or dump data
            print(f"===========rolling{rolling_offset} end===========")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: hand the workflow class to python-fire, e.g.
    #   python <this_script> rolling_process
    fire.Fire(RollingDataWorkflow)