Mirror of https://github.com/microsoft/qlib.git
Ptnn4both datatypes and alignment tests (#1827)
* Init model for both dataset
* Remove some deprecated code
* Add model template
* We must align with previous results
* We choose another mode as the initial version
* Almost success to run GRU
* Successfully run training
* Passed general_nn test
* gru test
* Alignment test passed
* comment
* fix readme & minor errors
* general nn updates & benchmarks
* Update examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml

---------

Co-authored-by: Young <afe.young@gmail.com>
Co-authored-by: you-n-g <you-n-g@users.noreply.github.com>
This commit is contained in:
Parent
2c33332dd6
Commit
c9ed050ef0
@ -0,0 +1,19 @@
# Introduction

What is GeneralPtNN?

- It fixes the previous design, which failed to support both time-series and tabular data.
- Now you can simply replace the PyTorch model structure to run a new NN model.

We provide examples to demonstrate the effectiveness of the current design.

- `workflow_config_gru.yaml` aligns with the previous results of [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158-dataset)
- `workflow_config_gru2mlp.yaml` demonstrates that a config can be converted from time-series to tabular data with minimal changes
  - You only have to change the network & dataset class to make the conversion, as sketched below.
- `workflow_config_mlp.yaml` achieves functionality similar to [MLP](../README.md#Alpha158-dataset)
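
To make the conversion concrete, here is an abridged sketch (not a complete file) of the only fields that differ between `workflow_config_gru.yaml` and `workflow_config_gru2mlp.yaml`, both of which appear in full in this commit; everything else stays the same:

```yaml
task:
    model:
        kwargs:
            # net: swap the time-series GRU for a tabular MLP
            pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"  # was "qlib.contrib.model.pytorch_gru_ts.GRUModel"
            pt_model_kwargs:
                input_dim: 20              # was d_feat/hidden_size/num_layers/dropout for the GRU
                layers: [20,]
    dataset:
        class: DatasetH                    # was TSDatasetH; step_len is dropped as well
```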

# TODO

- We will align existing models with the current design.
- The result of `workflow_config_mlp.yaml` differs from the result of [MLP](../README.md#Alpha158-dataset) because GeneralPtNN uses a different stopping method: it controls training by the number of epochs, whereas the previous implementation controlled it by `max_steps`.
@ -0,0 +1,100 @@
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: FilterCol
          kwargs:
              fields_group: feature
              col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
                         "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
                         "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
                        ]
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -2) / Ref($close, -1) - 1"]

port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy
        kwargs:
            signal: <PRED>
            topk: 50
            n_drop: 5
    backtest:
        start_time: 2017-01-01
        end_time: 2020-08-01
        account: 100000000
        benchmark: *benchmark
        exchange_kwargs:
            limit_threshold: 0.095
            deal_price: close
            open_cost: 0.0005
            close_cost: 0.0015
            min_cost: 5
task:
    model:
        class: GeneralPTNN
        module_path: qlib.contrib.model.pytorch_general_nn
        kwargs:
            n_epochs: 200
            lr: 2e-4
            early_stop: 10
            batch_size: 800
            metric: loss
            loss: mse
            n_jobs: 20
            GPU: 0
            pt_model_uri: "qlib.contrib.model.pytorch_gru_ts.GRUModel"
            pt_model_kwargs: {
                "d_feat": 20,
                "hidden_size": 64,
                "num_layers": 2,
                "dropout": 0.,
            }
    dataset:
        class: TSDatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha158
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
            step_len: 20
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              model: <MODEL>
              dataset: <DATASET>
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              ana_long_short: False
              ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              config: *port_analysis_config
@ -0,0 +1,93 @@
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors:
        - class: FilterCol
          kwargs:
              fields_group: feature
              col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10",
                         "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5",
                         "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"
                        ]
        - class: RobustZScoreNorm
          kwargs:
              fields_group: feature
              clip_outlier: true
        - class: Fillna
          kwargs:
              fields_group: feature
    learn_processors:
        - class: DropnaLabel
        - class: CSRankNorm
          kwargs:
              fields_group: label
    label: ["Ref($close, -2) / Ref($close, -1) - 1"]

port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy
        kwargs:
            signal: <PRED>
            topk: 50
            n_drop: 5
    backtest:
        start_time: 2017-01-01
        end_time: 2020-08-01
        account: 100000000
        benchmark: *benchmark
        exchange_kwargs:
            limit_threshold: 0.095
            deal_price: close
            open_cost: 0.0005
            close_cost: 0.0015
            min_cost: 5
task:
    model:
        class: GeneralPTNN
        module_path: qlib.contrib.model.pytorch_general_nn
        kwargs:
            lr: 1e-3
            n_epochs: 1
            batch_size: 800
            loss: mse
            optimizer: adam
            pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
            pt_model_kwargs:
                input_dim: 20
                layers: [20,]
    dataset:
        class: DatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha158
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              model: <MODEL>
              dataset: <DATASET>
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              ana_long_short: False
              ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              config: *port_analysis_config
@ -0,0 +1,98 @@
qlib_init:
    provider_uri: "~/.qlib/qlib_data/cn_data"
    region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
    start_time: 2008-01-01
    end_time: 2020-08-01
    fit_start_time: 2008-01-01
    fit_end_time: 2014-12-31
    instruments: *market
    infer_processors: [
        {
            "class" : "DropCol",
            "kwargs":{"col_list": ["VWAP0"]}
        },
        {
            "class" : "CSZFillna",
            "kwargs":{"fields_group": "feature"}
        }
    ]
    learn_processors: [
        {
            "class" : "DropCol",
            "kwargs":{"col_list": ["VWAP0"]}
        },
        {
            "class" : "DropnaProcessor",
            "kwargs":{"fields_group": "feature"}
        },
        "DropnaLabel",
        {
            "class": "CSZScoreNorm",
            "kwargs": {"fields_group": "label"}
        }
    ]
    process_type: "independent"

port_analysis_config: &port_analysis_config
    strategy:
        class: TopkDropoutStrategy
        module_path: qlib.contrib.strategy
        kwargs:
            signal: <PRED>
            topk: 50
            n_drop: 5
    backtest:
        start_time: 2017-01-01
        end_time: 2020-08-01
        account: 100000000
        benchmark: *benchmark
        exchange_kwargs:
            limit_threshold: 0.095
            deal_price: close
            open_cost: 0.0005
            close_cost: 0.0015
            min_cost: 5
task:
    model:
        class: GeneralPTNN
        module_path: qlib.contrib.model.pytorch_general_nn
        kwargs:
            # FIXME: wrong parameters.
            lr: 2e-3
            batch_size: 8192
            loss: mse
            weight_decay: 0.0002
            optimizer: adam
            pt_model_uri: "qlib.contrib.model.pytorch_nn.Net"
            pt_model_kwargs:
                input_dim: 157
    dataset:
        class: DatasetH
        module_path: qlib.data.dataset
        kwargs:
            handler:
                class: Alpha158
                module_path: qlib.contrib.data.handler
                kwargs: *data_handler_config
            segments:
                train: [2008-01-01, 2014-12-31]
                valid: [2015-01-01, 2016-12-31]
                test: [2017-01-01, 2020-08-01]
    record:
        - class: SignalRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              model: <MODEL>
              dataset: <DATASET>
        - class: SigAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              ana_long_short: False
              ann_scaler: 252
        - class: PortAnaRecord
          module_path: qlib.workflow.record_temp
          kwargs:
              config: *port_analysis_config
@ -0,0 +1,353 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function

import copy
from typing import Union

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from qlib.data.dataset.weight import Reweighter

from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH, TSDatasetH
from ...data.dataset.handler import DataHandlerLP
from ...utils import (
    init_instance_by_config,
    get_or_create_path,
)
from ...log import get_module_logger
from ...model.utils import ConcatDataset


class GeneralPTNN(Model):
    """
    Motivation:
        We want to provide a general Qlib PyTorch model adaptor.
        You can reuse it for all kinds of PyTorch models.
        It includes the training and predict processes.

    Parameters
    ----------
    d_feat : int
        input dimension for each time step
    metric : str
        the evaluation metric used in early stop
    optimizer : str
        optimizer name
    GPU : str
        the GPU ID(s) used for training
    """

    def __init__(
        self,
        n_epochs=200,
        lr=0.001,
        metric="",
        batch_size=2000,
        early_stop=20,
        loss="mse",
        weight_decay=0.0,
        optimizer="adam",
        n_jobs=10,
        GPU=0,
        seed=None,
        pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
        pt_model_kwargs={
            "d_feat": 6,
            "hidden_size": 64,
            "num_layers": 2,
            "dropout": 0.0,
        },
    ):
        # Set logger.
        self.logger = get_module_logger("GeneralPTNN")
        self.logger.info("GeneralPTNN pytorch version...")

        # Set hyper-parameters.
        self.n_epochs = n_epochs
        self.lr = lr
        self.metric = metric
        self.batch_size = batch_size
        self.early_stop = early_stop
        self.optimizer = optimizer.lower()
        self.loss = loss
        self.weight_decay = weight_decay
        self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
        self.n_jobs = n_jobs
        self.seed = seed

        # Instantiate the wrapped PyTorch network from its import path and kwargs.
        self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs
        self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})

        self.logger.info(
            "GeneralPTNN parameters setting:"
            "\nn_epochs : {}"
            "\nlr : {}"
            "\nmetric : {}"
            "\nbatch_size : {}"
            "\nearly_stop : {}"
            "\noptimizer : {}"
            "\nloss_type : {}"
            "\ndevice : {}"
            "\nn_jobs : {}"
            "\nuse_GPU : {}"
            "\nweight_decay : {}"
            "\nseed : {}"
            "\npt_model_uri: {}"
            "\npt_model_kwargs: {}".format(
                n_epochs,
                lr,
                metric,
                batch_size,
                early_stop,
                optimizer.lower(),
                loss,
                self.device,
                n_jobs,
                self.use_gpu,
                weight_decay,
                seed,
                pt_model_uri,
                pt_model_kwargs,
            )
        )

        if self.seed is not None:
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)

        self.logger.info("model:\n{:}".format(self.dnn_model))
        self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))

        if optimizer.lower() == "adam":
            self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
        elif optimizer.lower() == "gd":
            self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay)
        else:
            raise NotImplementedError("optimizer {} is not supported!".format(optimizer))

        self.fitted = False
        self.dnn_model.to(self.device)

    @property
    def use_gpu(self):
        return self.device != torch.device("cpu")

    def mse(self, pred, label, weight):
        loss = weight * (pred - label) ** 2
        return torch.mean(loss)

    def loss_fn(self, pred, label, weight=None):
        mask = ~torch.isnan(label)

        if weight is None:
            weight = torch.ones_like(label)

        if self.loss == "mse":
            return self.mse(pred[mask], label[mask], weight[mask])

        raise ValueError("unknown loss `%s`" % self.loss)

    def metric_fn(self, pred, label):
        mask = torch.isfinite(label)

        if self.metric in ("", "loss"):
            return -self.loss_fn(pred[mask], label[mask])

        raise ValueError("unknown metric `%s`" % self.metric)

    def _get_fl(self, data: torch.Tensor):
        """
        Get feature and label from data.
        - Handle the different data shapes of time-series and tabular data.

        Parameters
        ----------
        data : torch.Tensor
            input data, which may be 3-dimensional or 2-dimensional
            - 3 dim: [batch_size, time_step, feature_dim]
            - 2 dim: [batch_size, feature_dim]

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
        """
        if data.dim() == 3:
            # it is a time-series dataset
            feature = data[:, :, 0:-1].to(self.device)
            label = data[:, -1, -1].to(self.device)
        elif data.dim() == 2:
            # it is a tabular dataset
            feature = data[:, 0:-1].to(self.device)
            label = data[:, -1].to(self.device)
        else:
            raise ValueError("Unsupported data shape.")
        return feature, label

    def train_epoch(self, data_loader):
        self.dnn_model.train()

        for data, weight in data_loader:
            feature, label = self._get_fl(data)

            pred = self.dnn_model(feature.float())
            loss = self.loss_fn(pred, label, weight.to(self.device))

            self.train_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(self.dnn_model.parameters(), 3.0)
            self.train_optimizer.step()

    def test_epoch(self, data_loader):
        self.dnn_model.eval()

        scores = []
        losses = []

        for data, weight in data_loader:
            feature, label = self._get_fl(data)

            with torch.no_grad():
                pred = self.dnn_model(feature.float())
                loss = self.loss_fn(pred, label, weight.to(self.device))
                losses.append(loss.item())

                score = self.metric_fn(pred, label)
                scores.append(score.item())

        return np.mean(losses), np.mean(scores)

    def fit(
        self,
        dataset: Union[DatasetH, TSDatasetH],
        evals_result=dict(),
        save_path=None,
        reweighter=None,
    ):
        ists = isinstance(dataset, TSDatasetH)  # is this a time-series dataset?

        dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
        dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
        if dl_train.empty or dl_valid.empty:
            raise ValueError("Empty data from dataset, please check your dataset config.")

        if reweighter is None:
            wl_train = np.ones(len(dl_train))
            wl_valid = np.ones(len(dl_valid))
        elif isinstance(reweighter, Reweighter):
            wl_train = reweighter.reweight(dl_train)
            wl_valid = reweighter.reweight(dl_valid)
        else:
            raise ValueError("Unsupported reweighter type.")

        # Preprocess the data to align with the Dataset interface for DataLoader
        if ists:
            dl_train.config(fillna_type="ffill+bfill")  # process nan brought by dataloader
            dl_valid.config(fillna_type="ffill+bfill")  # process nan brought by dataloader
        else:
            # If it is tabular data, convert the dataframe to numpy so it is indexable by DataLoader
            dl_train = dl_train.values
            dl_valid = dl_valid.values

        train_loader = DataLoader(
            ConcatDataset(dl_train, wl_train),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_jobs,
            drop_last=True,
        )
        valid_loader = DataLoader(
            ConcatDataset(dl_valid, wl_valid),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.n_jobs,
            drop_last=True,
        )
        del dl_train, dl_valid, wl_train, wl_valid

        save_path = get_or_create_path(save_path)

        stop_steps = 0
        train_loss = 0
        best_score = -np.inf
        best_epoch = 0
        evals_result["train"] = []
        evals_result["valid"] = []

        # train
        self.logger.info("training...")
        self.fitted = True

        for step in range(self.n_epochs):
            self.logger.info("Epoch%d:", step)
            self.logger.info("training...")
            self.train_epoch(train_loader)
            self.logger.info("evaluating...")
            train_loss, train_score = self.test_epoch(train_loader)
            val_loss, val_score = self.test_epoch(valid_loader)
            self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
            evals_result["train"].append(train_score)
            evals_result["valid"].append(val_score)

            if step == 0:
                best_param = copy.deepcopy(self.dnn_model.state_dict())
            if val_score > best_score:
                best_score = val_score
                stop_steps = 0
                best_epoch = step
                best_param = copy.deepcopy(self.dnn_model.state_dict())
            else:
                stop_steps += 1
                if stop_steps >= self.early_stop:
                    self.logger.info("early stop")
                    break

        self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
        self.dnn_model.load_state_dict(best_param)
        torch.save(best_param, save_path)

        if self.use_gpu:
            torch.cuda.empty_cache()

    def predict(self, dataset: Union[DatasetH, TSDatasetH]):
        if not self.fitted:
            raise ValueError("model is not fitted yet!")

        dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)

        if isinstance(dataset, TSDatasetH):
            dl_test.config(fillna_type="ffill+bfill")  # process nan brought by dataloader
            index = dl_test.get_index()
        else:
            # If it is tabular data, convert the dataframe to numpy so it is indexable by DataLoader
            index = dl_test.index
            dl_test = dl_test.values

        test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
        self.dnn_model.eval()
        preds = []

        for data in test_loader:
            feature, _ = self._get_fl(data)
            feature = feature.to(self.device)

            with torch.no_grad():
                pred = self.dnn_model(feature.float()).detach().cpu().numpy()

            preds.append(pred)

        preds_concat = np.concatenate(preds)
        if preds_concat.ndim != 1:
            preds_concat = preds_concat.ravel()

        return pd.Series(preds_concat, index=index)
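
For orientation, here is a minimal usage sketch of the adaptor. It mirrors the unit test at the end of this commit and assumes `dataset` is an already-constructed `TSDatasetH` (or `DatasetH` when wrapping a tabular net); it is an illustration, not part of the committed code:

```python
# Minimal sketch: wrap an importable PyTorch nn.Module with GeneralPTNN.
# Assumption: `dataset` is a prepared TSDatasetH (time series) or DatasetH (tabular).
from qlib.contrib.model.pytorch_general_nn import GeneralPTNN

model = GeneralPTNN(
    n_epochs=2,
    pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",  # any importable network class
    pt_model_kwargs={"d_feat": 3, "hidden_size": 8, "num_layers": 1, "dropout": 0.0},
)
model.fit(dataset)             # trains with early stopping on the "valid" segment
pred = model.predict(dataset)  # returns a pd.Series indexed over the "test" segment
```
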
@ -317,7 +317,6 @@ class GRU(Model):

class GRUModel(nn.Module):
    def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
        super().__init__()

@ -0,0 +1,76 @@
import unittest

from qlib.tests import TestAutoData


class TestNN(TestAutoData):
    def test_both_dataset(self):
        try:
            from qlib.contrib.model.pytorch_general_nn import GeneralPTNN
            from qlib.data.dataset import DatasetH, TSDatasetH
            from qlib.data.dataset.handler import DataHandlerLP
        except ImportError:
            print("Import error.")
            return

        data_handler_config = {
            "start_time": "2008-01-01",
            "end_time": "2020-08-01",
            "instruments": "csi300",
            "data_loader": {
                "class": "QlibDataLoader",  # Assuming QlibDataLoader is a string reference to the class
                "kwargs": {
                    "config": {
                        "feature": [["$high", "$close", "$low"], ["H", "C", "L"]],
                        "label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]],
                    },
                    "freq": "day",
                },
            },
            # TODO: processors
            "learn_processors": [
                {
                    "class": "DropnaLabel",
                },
                {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
            ],
        }
        segments = {
            "train": ["2008-01-01", "2014-12-31"],
            "valid": ["2015-01-01", "2016-12-31"],
            "test": ["2017-01-01", "2020-08-01"],
        }
        data_handler = DataHandlerLP(**data_handler_config)

        # time-series dataset
        tsds = TSDatasetH(handler=data_handler, segments=segments)

        # tabular dataset
        tbds = DatasetH(handler=data_handler, segments=segments)

        model_l = [
            GeneralPTNN(
                n_epochs=2,
                pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel",
                pt_model_kwargs={
                    "d_feat": 3,
                    "hidden_size": 8,
                    "num_layers": 1,
                    "dropout": 0.0,
                },
            ),
            GeneralPTNN(
                n_epochs=2,
                pt_model_uri="qlib.contrib.model.pytorch_nn.Net",  # it is an MLP
                pt_model_kwargs={
                    "input_dim": 3,
                },
            ),
        ]

        for ds, model in list(zip((tsds, tbds), model_l)):
            model.fit(ds)  # It works
            model.predict(ds)  # It works


if __name__ == "__main__":
    unittest.main()