* Refine previous version of RL code

* Polish utils/__init__.py

* Draft

* Use | instead of Union

* Simulator & action interpreter

* Test passed

* Migrate to SAOEState & new qlib interpreter

* Black format

* Revert file_storage change

* Refactor file structure & rename functions

* Enrich test cases

* Add QlibIntradayBacktestData

* Test interpreter

* Black format

* .

* Rename receive_execute_result()

* Use indicator to simplify state update

* Format code

* Modify data path

* Adjust file structure

* Minor change

* Add copyright message

* Format code

* Rename util functions

* Add CI

* Pylint issue

* Remove useless code to pass pylint

* Pass mypy

* Mypy issue

* mypy issue

* mypy issue

* Revert "mypy issue"

This reverts commit 8eb1b0174e.

* mypy issue

* mypy issue

* Fix the numpy version incompatible bug

* Fix a minor typing issue

* Try to skip python 3.7 test for qlib simulator

* Resolve PR comments by Yuge; solve several CI issues.

* Black issue

* Fix a low-level type error

* Change data name

* Resolve PR comments. Leave TODOs in the code base.

Co-authored-by: Young <afe.young@gmail.com>
Huoran Li 2022-08-01 09:56:07 +08:00 committed by GitHub
Parent 687edd79d0
Commit 2752bdc92c
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
35 changed files: 1305 additions and 257 deletions

View file

@@ -42,7 +42,7 @@ def get_exchange(
close_cost: float = 0.0025,
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] = None,
deal_price: Union[str, Tuple[str], List[str]] = None,
deal_price: Union[str, Tuple[str, str], List[str]] = None,
**kwargs: Any,
) -> Exchange:
"""get_exchange
@@ -70,10 +70,10 @@ def get_exchange(
min_cost : float
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
deal_price: Union[str, Tuple[str], List[str]]
deal_price: Union[str, Tuple[str, str], List[str]]
The `deal_price` supports following two types of input
- <deal_price> : str
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
- (<buy_price>, <sell_price>): Tuple[str, str] or List[str]
<deal_price>, <buy_price> or <sell_price> := <price>
<price> := str

View file

@@ -4,10 +4,11 @@
from __future__ import annotations
from abc import abstractmethod
from datetime import time
from enum import IntEnum
# try to fix circular imports when enabling type hints
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, cast
from qlib.backtest.utils import TradeCalendarManager
from qlib.data.data import Cal
@@ -23,7 +24,6 @@ from dataclasses import dataclass
import numpy as np
import pandas as pd
DecisionType = TypeVar("DecisionType")
@@ -182,8 +182,8 @@ class OrderHelper:
return Order(
stock_id=code,
amount=amount,
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
start_time=None if start_time is None else pd.Timestamp(start_time),
end_time=None if end_time is None else pd.Timestamp(end_time),
direction=direction,
)
@@ -249,7 +249,7 @@ class IdxTradeRange(TradeRange):
class TradeRangeByTime(TradeRange):
"""This is a helper function for make decisions"""
def __init__(self, start_time: str, end_time: str) -> None:
def __init__(self, start_time: str | time, end_time: str | time) -> None:
"""
This is a callable class.
@@ -259,13 +259,13 @@ class TradeRangeByTime(TradeRange):
Parameters
----------
start_time : str
start_time : str | time
e.g. "9:30"
end_time : str
end_time : str | time
e.g. "14:30"
"""
self.start_time = pd.Timestamp(start_time).time()
self.end_time = pd.Timestamp(end_time).time()
self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time
self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time
assert self.start_time < self.end_time
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
@@ -535,7 +535,12 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
Besides, the time_range is also included.
"""
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
def __init__(
self,
order_list: List[Order],
strategy: BaseStrategy,
trade_range: Union[Tuple[int, int], TradeRange] = None,
) -> None:
super().__init__(strategy, trade_range=trade_range)
self.order_list = cast(List[Order], order_list)
start, end = strategy.trade_calendar.get_step_time()

View file

@@ -32,7 +32,7 @@ class Exchange:
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
deal_price: Union[str, Tuple[str], List[str]] = None,
deal_price: Union[str, Tuple[str, str], List[str]] = None,
subscribe_fields: list = [],
limit_threshold: Union[Tuple[str, str], float, None] = None,
volume_threshold: Union[tuple, dict] = None,
@@ -448,9 +448,9 @@ class Exchange:
start_time: pd.Timestamp,
end_time: pd.Timestamp,
method: Optional[str] = "sum",
) -> float:
) -> Union[None, int, float, bool, IndexData]:
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
def get_deal_price(
self,
@@ -459,7 +459,7 @@ class Exchange:
end_time: pd.Timestamp,
direction: OrderDir,
method: Optional[str] = "ts_data_last",
) -> float:
) -> Union[None, int, float, bool, IndexData]:
if direction == OrderDir.SELL:
pstr = self.sell_price
elif direction == OrderDir.BUY:
@@ -472,7 +472,7 @@ class Exchange:
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
deal_price = self.get_close(stock_id, start_time, end_time, method)
return cast(float, deal_price)
return deal_price
def get_factor(
self,
@@ -832,8 +832,11 @@ class Exchange:
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
:return: trade_price, trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
trade_price = cast(
float,
self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction),
)
total_trade_val = cast(float, self.get_volume(order.stock_id, order.start_time, order.end_time)) * trade_price
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
order.deal_amount = order.amount # set to full amount and clip it step by step
# Clipping amount first

View file

@@ -484,6 +484,7 @@ class NestedExecutor(BaseExecutor):
inner_exe_res :
the execution result of inner task
"""
self.inner_strategy.post_exe_step(inner_exe_res)
def get_all_executors(self) -> List[BaseExecutor]:
"""get all executors, including self and inner_executor.get_all_executors()"""

View file

@@ -102,11 +102,22 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
self._freq_file_cache = freq
return self._freq_file_cache
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
def _read_calendar(self) -> List[CalVT]:
# NOTE:
# If we want to accelerate partial reading of the calendar,
# we can add parameters like `skip_rows: int = 0, n_rows: int = None` to the interface.
# Currently, this is not supported for the txt-based calendar.
if not self.uri.exists():
self._write_calendar(values=[])
with self.uri.open("rb") as fp:
return [str(x) for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")]
with self.uri.open("r") as fp:
res = []
for line in fp.readlines():
line = line.strip()
if len(line) > 0:
res.append(line)
return res
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
with self.uri.open(mode=mode) as fp:

View file

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import Generic, TYPE_CHECKING, TypeVar
from typing import Optional, TYPE_CHECKING, Generic, TypeVar
from qlib.typehint import final
@@ -21,7 +21,7 @@ AuxInfoType = TypeVar("AuxInfoType")
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
"""Override this class to collect customized auxiliary information from environment."""
env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
@final
def __call__(self, simulator_state: StateType) -> AuxInfoType:

View file

@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import cast
import pandas as pd
from qlib.backtest import Exchange, Order
from .pickle_styled import IntradayBacktestData
class QlibIntradayBacktestData(IntradayBacktestData):
"""Backtest data for Qlib simulator"""
def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None:
super(QlibIntradayBacktestData, self).__init__()
self._order = order
self._exchange = exchange
self._start_time = start_time
self._end_time = end_time
self._deal_price = cast(
pd.Series,
self._exchange.get_deal_price(
self._order.stock_id,
self._start_time,
self._end_time,
direction=self._order.direction,
method=None,
),
)
self._volume = cast(
pd.Series,
self._exchange.get_volume(
self._order.stock_id,
self._start_time,
self._end_time,
method=None,
),
)
def __repr__(self) -> str:
return (
f"Order: {self._order}, Exchange: {self._exchange}, "
f"Start time: {self._start_time}, End time: {self._end_time}"
)
def __len__(self) -> int:
return len(self._deal_price)
def get_deal_price(self) -> pd.Series:
return self._deal_price
def get_volume(self) -> pd.Series:
return self._volume
def get_time_index(self) -> pd.DatetimeIndex:
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])

View file

@@ -19,19 +19,19 @@ This file shows resemblance to qlib.backtest.high_performance_ds. We might merge
from __future__ import annotations
from abc import abstractmethod
from functools import lru_cache
from typing import List, Sequence, cast
from pathlib import Path
from typing import List, Sequence, cast
import cachetools
import numpy as np
import pandas as pd
from cachetools.keys import hashkey
from qlib.backtest.decision import OrderDir, Order
from qlib.backtest.decision import Order, OrderDir
from qlib.typehint import Literal
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
"""Several ad-hoc deal price.
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
@@ -40,7 +40,7 @@ DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
"""
def _infer_processed_data_column_names(shape: int) -> list[str]:
def _infer_processed_data_column_names(shape: int) -> List[str]:
if shape == 16:
return [
"$open",
@@ -87,7 +87,36 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
class IntradayBacktestData:
"""Raw market data that is often used in backtesting (thus called BacktestData)."""
"""
Raw market data that is often used in backtesting (thus called BacktestData).
Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
data type.
"""
@abstractmethod
def __repr__(self) -> str:
raise NotImplementedError
@abstractmethod
def __len__(self) -> int:
raise NotImplementedError
@abstractmethod
def get_deal_price(self) -> pd.Series:
raise NotImplementedError
@abstractmethod
def get_volume(self) -> pd.Series:
raise NotImplementedError
@abstractmethod
def get_time_index(self) -> pd.DatetimeIndex:
raise NotImplementedError
class SimpleIntradayBacktestData(IntradayBacktestData):
"""Backtest data for simple simulator"""
def __init__(
self,
@@ -95,8 +124,10 @@ class IntradayBacktestData:
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int | None = None,
):
order_dir: int = None,
) -> None:
super(SimpleIntradayBacktestData, self).__init__()
backtest = _read_pickle(data_dir / stock_id)
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
@@ -105,13 +136,13 @@ class IntradayBacktestData:
self.data: pd.DataFrame = backtest
self.deal_price_type: DealPriceType = deal_price
self.order_dir: int | None = order_dir
self.order_dir = order_dir
def __repr__(self):
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.data})"
def __len__(self):
def __len__(self) -> int:
return len(self.data)
def get_deal_price(self) -> pd.Series:
@@ -162,7 +193,14 @@ class IntradayProcessedData:
"""Processed data for "yesterday".
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
def __init__(
self,
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
) -> None:
proc = _read_pickle(data_dir / stock_id)
# We have to infer the names here because,
# unfortunately they are not included in the original data.
@@ -190,16 +228,20 @@ class IntradayProcessedData:
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
assert len(self.today) == len(self.yesterday) == time_length
def __repr__(self):
def __repr__(self) -> str:
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
@lru_cache(maxsize=100) # 100 * 50K = 5MB
def load_intraday_backtest_data(
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
) -> IntradayBacktestData:
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
def load_simple_intraday_backtest_data(
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
deal_price: DealPriceType = "close",
order_dir: int = None,
) -> SimpleIntradayBacktestData:
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
@cachetools.cached( # type: ignore
@@ -207,13 +249,19 @@ def load_intraday_backtest_data(
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
)
def load_intraday_processed_data(
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
data_dir: Path,
stock_id: str,
date: pd.Timestamp,
feature_dim: int,
time_index: pd.Index,
) -> IntradayProcessedData:
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
def load_orders(
order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
order_path: Path,
start_time: pd.Timestamp = None,
end_time: pd.Timestamp = None,
) -> Sequence[Order]:
"""Load orders, and set start time and end time for the orders."""
@@ -251,7 +299,7 @@ def load_orders(
OrderDir(int(row["order_type"])),
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
)
),
)
return orders

View file

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TODO: find a better way to organize contents under this module.

View file

@@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
# TODO: In the future we should merge the dataclass-based config with Qlib's dict-based config.
@dataclass
class ExchangeConfig:
limit_threshold: Union[float, Tuple[str, str]]
deal_price: Union[str, Tuple[str, str]]
volume_threshold: dict
open_cost: float = 0.0005
close_cost: float = 0.0015
min_cost: float = 5.0
trade_unit: Optional[float] = 100.0
cash_limit: Optional[Union[Path, float]] = None
generate_report: bool = False
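
A minimal instantiation sketch of this config (the values below are illustrative, not from this PR):

cfg = ExchangeConfig(
    limit_threshold=0.095,  # plain ratio form; a (buy_expr, sell_expr) tuple is also allowed
    deal_price="close",
    volume_threshold={},    # no volume restriction for this sketch
)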

View file

@@ -0,0 +1,109 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import collections
from typing import List, Optional
import pandas as pd
import qlib
from qlib.config import REG_CN
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
from qlib.data.dataset import DatasetH
class LRUCache:
def __init__(self, pool_size: int = 200):
self.pool_size = pool_size
self.contents: dict = {}
self.keys: collections.deque = collections.deque()
def put(self, key, item):
if self.has(key):
self.keys.remove(key)
self.keys.append(key)
self.contents[key] = item
while len(self.contents) > self.pool_size:
self.contents.pop(self.keys.popleft())
def get(self, key):
return self.contents[key]
def has(self, key):
return key in self.contents
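
A quick sanity check of the eviction behavior (illustrative only; eviction follows insertion order, refreshed when a key is put again):

cache = LRUCache(pool_size=2)
cache.put("a", 1)
cache.put("b", 2)
cache.put("c", 3)  # exceeds pool_size, so the oldest key "a" is evicted
assert not cache.has("a") and cache.get("c") == 3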
class DataWrapper:
def __init__(
self,
feature_dataset: DatasetH,
backtest_dataset: DatasetH,
columns_today: List[str],
columns_yesterday: List[str],
_internal: bool = False,
):
assert _internal, "Init function of data wrapper is for internal use only."
self.feature_dataset = feature_dataset
self.backtest_dataset = backtest_dataset
self.columns_today = columns_today
self.columns_yesterday = columns_yesterday
# TODO: We might have the chance to merge them.
self.feature_cache = LRUCache()
self.backtest_cache = LRUCache()
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
if backtest:
dataset = self.backtest_dataset
cache = self.backtest_cache
else:
dataset = self.feature_dataset
cache = self.feature_cache
if cache.has((start_time, end_time, stock_id)):
return cache.get((start_time, end_time, stock_id))
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
cache.put((start_time, end_time, stock_id), data)
return data
def init_qlib(config: dict, part: Optional[str] = None) -> None:
provider_uri_map = {
"day": config["provider_uri_day"].as_posix(),
"1min": config["provider_uri_1min"].as_posix(),
}
qlib.init(
region=REG_CN,
auto_mount=False,
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],
expression_cache=None,
calendar_provider={
"class": "LocalCalendarProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileCalendarStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
},
},
},
feature_provider={
"class": "LocalFeatureProvider",
"module_path": "qlib.data.data",
"kwargs": {
"backend": {
"class": "FileFeatureStorage",
"module_path": "qlib.data.storage.file_storage",
"kwargs": {"provider_uri_map": provider_uri_map},
},
},
},
provider_uri=provider_uri_map,
kernels=1,
redis_port=-1,
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
)
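
A hedged sketch of calling init_qlib, assuming the config dict carries pathlib.Path values for the two provider keys read above (the paths themselves are hypothetical):

from pathlib import Path

init_qlib(
    {
        "provider_uri_day": Path("~/.qlib/qlib_data/cn_data").expanduser(),        # hypothetical path
        "provider_uri_1min": Path("~/.qlib/qlib_data/cn_data_1min").expanduser(),  # hypothetical path
    }
)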

View file

@@ -3,13 +3,13 @@
from __future__ import annotations
from typing import TYPE_CHECKING, TypeVar, Generic, Any
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import numpy as np
from qlib.typehint import final
from .simulator import StateType, ActType
from .simulator import ActType, StateType
if TYPE_CHECKING:
from .utils.env_wrapper import EnvWrapper
@@ -40,7 +40,7 @@ class Interpreter:
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
@property
def observation_space(self) -> gym.Space:
@@ -74,7 +74,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
env: "EnvWrapper" | None = None
env: Optional[EnvWrapper] = None
@property
def action_space(self) -> gym.Space:
@@ -141,10 +141,10 @@ def _gym_space_contains(space: gym.Space, x: Any) -> None:
class GymSpaceValidationError(Exception):
def __init__(self, message: str, space: gym.Space, x: Any):
def __init__(self, message: str, space: gym.Space, x: Any) -> None:
self.message = message
self.space = space
self.x = x
def __str__(self):
def __str__(self) -> str:
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"

View file

@@ -5,15 +5,15 @@ from __future__ import annotations
import math
from pathlib import Path
from typing import Any, cast
from typing import Any, List, cast
import numpy as np
import pandas as pd
from gym import spaces
from qlib.constant import EPS
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.data import pickle_styled
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.typehint import TypedDict
from .simulator_simple import SAOEState
@@ -99,18 +99,18 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
"data_processed": self._mask_future_info(processed.today, state.cur_time),
"data_processed_prev": processed.yesterday,
"acquiring": state.order.direction == state.order.BUY,
"cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
"cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1),
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
"num_step": self.max_step,
"target": state.order.amount,
"position": state.position,
"position_history": position_history[: self.max_step],
}
},
),
)
@property
def observation_space(self):
def observation_space(self) -> spaces.Dict:
space = {
"data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
"data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
@@ -147,11 +147,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
The key list is not full. You can add more if more information is needed by your policy.
"""
def __init__(self, max_step: int):
def __init__(self, max_step: int) -> None:
self.max_step = max_step
@property
def observation_space(self):
def observation_space(self) -> spaces.Dict:
space = {
"acquiring": spaces.Discrete(2),
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
@@ -165,13 +165,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
assert self.env is not None
assert self.env.status["cur_step"] <= self.max_step
obs = CurrentStateObs(
{
"acquiring": state.order.direction == state.order.BUY,
"cur_step": self.env.status["cur_step"],
"num_step": self.max_step,
"target": state.order.amount,
"position": state.position,
}
acquiring=state.order.direction == state.order.BUY,
cur_step=self.env.status["cur_step"],
num_step=self.max_step,
target=state.order.amount,
position=state.position,
)
return obs
@@ -188,7 +186,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
"""
def __init__(self, values: int | list[float]):
def __init__(self, values: int | List[float]) -> None:
if isinstance(values, int):
values = [i / values for i in range(0, values + 1)]
self.action_values = values
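
For example, an integer input discretizes the [0, 1] ratio range evenly:

interpreter = CategoricalActionInterpreter(values=4)
assert interpreter.action_values == [0.0, 0.25, 0.5, 0.75, 1.0]  # i / 4 for i in 0..4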
@@ -203,7 +201,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
"""Convert a continous ratio to deal amount.
"""Convert a continuous ratio to deal amount.
The ratio is relative to TWAP on the remainder of the day.
For example, there are 5 steps left, and the left position is 300.

View file

@@ -3,13 +3,14 @@
from __future__ import annotations
from typing import cast
from typing import List, Tuple, cast
import torch
import torch.nn as nn
from tianshou.data import Batch
from qlib.typehint import Literal
from .interpreter import FullHistoryObs
__all__ = ["Recurrent"]
@@ -18,7 +19,7 @@ __all__ = ["Recurrent"]
class Recurrent(nn.Module):
"""The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
At every timestep the input of policy network is divided into two parts,
At every time step the input of policy network is divided into two parts,
the public variables and the private variables, which are handled by ``raw_rnn``
and ``pri_rnn`` in this network, respectively.
@@ -33,7 +34,7 @@ class Recurrent(nn.Module):
output_dim: int = 32,
rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
rnn_num_layers: int = 1,
):
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
@@ -62,10 +63,10 @@ class Recurrent(nn.Module):
nn.ReLU(),
)
def _init_extra_branches(self):
def _init_extra_branches(self) -> None:
pass
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:
bs, _, data_dim = obs["data_processed"].size()
data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
cur_step = obs["cur_step"].long()

View file

@@ -1,16 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from pathlib import Path
from typing import Optional, cast
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast
import numpy as np
import gym
import numpy as np
import torch
import torch.nn as nn
from gym.spaces import Discrete
from tianshou.data import Batch, to_torch
from tianshou.policy import PPOPolicy, BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import BasePolicy, PPOPolicy
__all__ = ["AllOne", "PPO"]
@@ -18,29 +19,39 @@ __all__ = ["AllOne", "PPO"]
# baselines #
class NonlearnablePolicy(BasePolicy):
class NonLearnablePolicy(BasePolicy):
"""Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
This could be moved outside in future.
"""
def __init__(self, obs_space: gym.Space, action_space: gym.Space):
def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None:
super().__init__()
def learn(self, batch, batch_size, repeat):
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
pass
def process_fn(self, batch, buffer, indice):
def process_fn(
self,
batch: Batch,
buffer: ReplayBuffer,
indices: np.ndarray,
) -> Batch:
pass
class AllOne(NonlearnablePolicy):
class AllOne(NonLearnablePolicy):
"""Forward returns a batch full of 1.
Useful when implementing some baselines (e.g., TWAP).
"""
def forward(self, batch, state=None, **kwargs):
def forward(
self,
batch: Batch,
state: dict | Batch | np.ndarray = None,
**kwargs: Any,
) -> Batch:
return Batch(act=np.full(len(batch), 1.0), state=state)
@@ -48,24 +59,34 @@ class AllOne(NonlearnablePolicy):
class PPOActor(nn.Module):
def __init__(self, extractor: nn.Module, action_dim: int):
def __init__(self, extractor: nn.Module, action_dim: int) -> None:
super().__init__()
self.extractor = extractor
self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
def forward(self, obs, state=None, info={}):
def forward(
self,
obs: torch.Tensor,
state: torch.Tensor = None,
info: dict = {},
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
feature = self.extractor(to_torch(obs, device=auto_device(self)))
out = self.layer_out(feature)
return out, state
class PPOCritic(nn.Module):
def __init__(self, extractor: nn.Module):
def __init__(self, extractor: nn.Module) -> None:
super().__init__()
self.extractor = extractor
self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
def forward(self, obs, state=None, info={}):
def forward(
self,
obs: torch.Tensor,
state: torch.Tensor = None,
info: dict = {},
) -> torch.Tensor:
feature = self.extractor(to_torch(obs, device=auto_device(self)))
return self.value_out(feature).squeeze(dim=-1)
@@ -93,18 +114,20 @@ class PPO(PPOPolicy):
max_grad_norm: float = 100.0,
reward_normalization: bool = True,
eps_clip: float = 0.3,
value_clip: float = True,
value_clip: bool = True,
vf_coef: float = 1.0,
gae_lambda: float = 1.0,
max_batchsize: int = 256,
max_batch_size: int = 256,
deterministic_eval: bool = True,
weight_file: Optional[Path] = None,
):
) -> None:
assert isinstance(action_space, Discrete)
actor = PPOActor(network, action_space.n)
critic = PPOCritic(network)
optimizer = torch.optim.Adam(
chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
chain_dedup(actor.parameters(), critic.parameters()),
lr=lr,
weight_decay=weight_decay,
)
super().__init__(
actor,
@@ -118,7 +141,7 @@ class PPO(PPOPolicy):
value_clip=value_clip,
vf_coef=vf_coef,
gae_lambda=gae_lambda,
max_batchsize=max_batchsize,
max_batchsize=max_batch_size,
deterministic_eval=deterministic_eval,
observation_space=obs_space,
action_space=action_space,
@@ -136,7 +159,7 @@ def auto_device(module: nn.Module) -> torch.device:
return torch.device("cpu") # fallback to cpu
def load_weight(policy, path):
def load_weight(policy: nn.Module, path: Path) -> None:
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
loaded_weight = torch.load(path, map_location="cpu")
try:
@@ -149,7 +172,7 @@ def load_weight(policy, path):
policy.load_state_dict(loaded_weight)
def chain_dedup(*iterables):
def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]:
seen = set()
for iterable in iterables:
for i in iterable:

View file

@ -6,9 +6,10 @@ from __future__ import annotations
from typing import cast
import numpy as np
from qlib.rl.reward import Reward
from .simulator_simple import SAOEState, SAOEMetrics
from .simulator_simple import SAOEMetrics, SAOEState
__all__ = ["PAPenaltyReward"]

View file

@@ -1,4 +1,424 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Placeholder for qlib-based simulator."""
from __future__ import annotations
from typing import Any, Callable, cast, Generator, List, Optional, Tuple
import numpy as np
import pandas as pd
from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange, TradeRangeByTime
from qlib.backtest.executor import BaseExecutor, NestedExecutor
from qlib.backtest.utils import CommonInfrastructure
from qlib.constant import EPS
from qlib.rl.data.exchange_wrapper import QlibIntradayBacktestData
from qlib.rl.from_neutrader.config import ExchangeConfig
from qlib.rl.from_neutrader.feature import init_qlib
from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState
from qlib.rl.order_execution.utils import (
dataframe_append,
get_common_infra,
get_portfolio_and_indicator,
get_ticks_slice,
price_advantage,
)
from qlib.rl.simulator import Simulator
from qlib.strategy.base import BaseStrategy
class DecomposedStrategy(BaseStrategy):
def __init__(self) -> None:
super().__init__()
self.execute_order: Optional[Order] = None
self.execute_result: List[Tuple[Order, float, float, float]] = []
def generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
# Once the following line is executed, this DecomposedStrategy (self) will be yielded to the outside
# of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`,
# the sent item will be captured by `exec_vol`. The outside policy can communicate with the inner-level
# strategy in this way (see the sketch after this class).
exec_vol = yield self
oh = self.trade_exchange.get_order_helper()
order = oh.create(self._order.stock_id, exec_vol, self._order.direction)
self.execute_order = order
return TradeDecisionWO([order], self)
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
return outer_trade_decision
def post_exe_step(self, execute_result: list) -> None:
self.execute_result = execute_result
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
if outer_trade_decision is not None:
order_list = outer_trade_decision.order_list
assert len(order_list) == 1
self._order = order_list[0]
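
To make the suspension mechanism concrete, here is a minimal, self-contained sketch of the same yield/send pattern (illustrative names only, not part of this PR):

def decide():
    exec_vol = yield "strategy"  # suspend; resumes with the value passed to send()
    return f"decision for volume {exec_vol}"

gen = decide()
assert next(gen) == "strategy"  # run until the first yield
try:
    gen.send(100.0)  # resume execution with the chosen volume
except StopIteration as stop:
    assert stop.value == "decision for volume 100.0"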
class SingleOrderStrategy(BaseStrategy):
# this logic is copied from FileOrderStrategy
def __init__(
self,
common_infra: CommonInfrastructure,
order: Order,
trade_range: TradeRange,
instrument: str,
) -> None:
super().__init__(common_infra=common_infra)
self._order = order
self._trade_range = trade_range
self._instrument = instrument
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
return outer_trade_decision
def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
order_list = [
oh.create(
code=self._instrument,
amount=self._order.amount,
direction=self._order.direction,
),
]
return TradeDecisionWO(order_list, self, self._trade_range)
# TODO: move these to the configuration files
FINEST_GRANULARITY = "1min"
COARSEST_GRANULARITY = "1day"
class StateMaintainer:
"""
Maintain states of the environment.
Example usage::
maintainer = StateMaintainer(...) # in reset
maintainer.update(...) # in step
# get states in get_state from maintainer
"""
def __init__(self, order: Order, time_per_step: str, tick_index: pd.DatetimeIndex, twap_price: float) -> None:
super().__init__()
self.position = order.amount
self._order = order
self._time_per_step = time_per_step
self._tick_index = tick_index
self._twap_price = twap_price
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.metrics: Optional[SAOEMetrics] = None
def update(
self,
inner_executor: BaseExecutor,
inner_strategy: DecomposedStrategy,
done: bool,
all_indicators: dict,
) -> None:
execute_order = inner_strategy.execute_order
execute_result = inner_strategy.execute_result
exec_vol = np.array([e[0].deal_amount for e in execute_result])
num_step = len(execute_result)
assert execute_order is not None
if num_step == 0:
market_volume = np.array([])
market_price = np.array([])
datetime_list = pd.DatetimeIndex([])
else:
market_volume = np.array(
inner_executor.trade_exchange.get_volume(
execute_order.stock_id,
execute_result[0][0].start_time,
execute_result[-1][0].start_time,
method=None,
),
)
trade_value = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["value"].values
deal_amount = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["deal_amount"].values
market_price = trade_value / deal_amount
datetime_list = all_indicators[FINEST_GRANULARITY].index[-num_step:]
assert market_price.shape == market_volume.shape == exec_vol.shape
self.history_exec = dataframe_append(
self.history_exec,
self._collect_multi_order_metric(
order=self._order,
datetime=datetime_list,
market_vol=market_volume,
market_price=market_price,
exec_vol=exec_vol,
pa=all_indicators[self._time_per_step].iloc[-1]["pa"],
),
)
self.history_steps = dataframe_append(
self.history_steps,
[
self._collect_single_order_metric(
execute_order,
execute_order.start_time,
market_volume,
market_price,
exec_vol.sum(),
exec_vol,
),
],
)
if done:
self.metrics = self._collect_single_order_metric(
self._order,
self._tick_index[0], # start time
self.history_exec["market_volume"],
self.history_exec["market_price"],
self.history_steps["amount"].sum(),
self.history_exec["deal_amount"],
)
# TODO: check whether we need this. Can we get this information from Account?
# Do this at the end
self.position -= exec_vol.sum()
def _collect_multi_order_metric(
self,
order: Order,
datetime: pd.Timestamp,
market_vol: np.ndarray,
market_price: np.ndarray,
exec_vol: np.ndarray,
pa: float,
) -> SAOEMetrics:
return SAOEMetrics(
# It should have the same keys as SAOEMetrics,
# but the values do not necessarily have the annotated type.
# Some values could be vectorized (e.g., exec_vol).
stock_id=order.stock_id,
datetime=datetime,
direction=order.direction,
market_volume=market_vol,
market_price=market_price,
amount=exec_vol,
inner_amount=exec_vol,
deal_amount=exec_vol,
trade_price=market_price,
trade_value=market_price * exec_vol,
position=self.position - np.cumsum(exec_vol),
ffr=exec_vol / order.amount,
pa=pa,
)
def _collect_single_order_metric(
self,
order: Order,
datetime: pd.Timestamp,
market_vol: np.ndarray,
market_price: np.ndarray,
amount: float, # intended to trade such amount
exec_vol: np.ndarray,
) -> SAOEMetrics:
assert len(market_vol) == len(market_price) == len(exec_vol)
if np.abs(np.sum(exec_vol)) < EPS:
exec_avg_price = 0.0
else:
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
if hasattr(exec_avg_price, "item"): # could be numpy scalar
exec_avg_price = exec_avg_price.item() # type: ignore
exec_sum = exec_vol.sum()
return SAOEMetrics(
stock_id=order.stock_id,
datetime=datetime,
direction=order.direction,
market_volume=market_vol.sum(),
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
amount=amount,
inner_amount=exec_sum,
deal_amount=exec_sum,  # in this simulator, there are no other restrictions
trade_price=exec_avg_price,
trade_value=float(np.sum(market_price * exec_vol)),
position=self.position - exec_sum,
ffr=float(exec_sum / order.amount),
pa=price_advantage(exec_avg_price, self._twap_price, order.direction),
)
class SingleAssetOrderExecutionQlib(Simulator[Order, SAOEState, float]):
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
Parameters
----------
order (Order):
The seed to start an SAOE simulator is an order.
time_per_step (str):
A string to describe the time granularity of each step. Currently supports "1min", "30min", and "1day".
qlib_config (dict):
Configuration used to initialize Qlib.
inner_executor_fn (Callable[[str, CommonInfrastructure], BaseExecutor]):
Function used to get the inner level executor.
exchange_config (ExchangeConfig):
Configuration used to create the Exchange instance.
"""
def __init__(
self,
order: Order,
time_per_step: str, # "1min", "30min", "1day"
qlib_config: dict,
inner_executor_fn: Callable[[str, CommonInfrastructure], BaseExecutor],
exchange_config: ExchangeConfig,
) -> None:
assert time_per_step in ("1min", "30min", "1day")
super().__init__(initial=order)
assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same."
self._order = order
self._order_date = pd.Timestamp(order.start_time.date())
self._trade_range = TradeRangeByTime(order.start_time.time(), order.end_time.time())
self._qlib_config = qlib_config
self._inner_executor_fn = inner_executor_fn
self._exchange_config = exchange_config
self._time_per_step = time_per_step
self._ticks_per_step = int(pd.Timedelta(time_per_step).total_seconds() // 60)
self._executor: Optional[NestedExecutor] = None
self._collect_data_loop: Optional[Generator] = None
self._done = False
self._inner_strategy = DecomposedStrategy()
self.reset(self._order)
def reset(self, order: Order) -> None:
instrument = order.stock_id
# TODO: Check this logic. Make sure we need to do this every time we reset the simulator.
init_qlib(self._qlib_config, instrument)
common_infra = get_common_infra(
self._exchange_config,
trade_date=pd.Timestamp(self._order_date),
codes=[instrument],
)
# TODO: We can leverage interfaces like (https://tinyurl.com/y8f8fhv4) to create trading environment.
# TODO: By aligning the interface to create environments with Qlib, it will be easier to share the config and
# TODO: code between backtesting and training.
self._inner_executor = self._inner_executor_fn(self._time_per_step, common_infra)
self._executor = NestedExecutor(
time_per_step=COARSEST_GRANULARITY,
inner_executor=self._inner_executor,
inner_strategy=self._inner_strategy,
track_data=True,
common_infra=common_infra,
)
exchange = self._inner_executor.trade_exchange
self._ticks_index = pd.DatetimeIndex([e[1] for e in list(exchange.quote_df.index)])
self._ticks_for_order = get_ticks_slice(
self._ticks_index,
self._order.start_time,
self._order.end_time,
include_end=True,
)
self._backtest_data = QlibIntradayBacktestData(
order=self._order,
exchange=exchange,
start_time=self._ticks_for_order[0],
end_time=self._ticks_for_order[-1],
)
self.twap_price = self._backtest_data.get_deal_price().mean()
top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument)
self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date))
top_strategy.reset(level_infra=self._executor.get_level_infra())
self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0)
assert isinstance(self._collect_data_loop, Generator)
self._iter_strategy(action=None)
self._done = False
self._maintainer = StateMaintainer(
order=self._order,
time_per_step=self._time_per_step,
tick_index=self._ticks_index,
twap_price=self.twap_price,
)
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
"""Iterate the _collect_data_loop until we get the next yield DecomposedStrategy."""
assert self._collect_data_loop is not None
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
while not isinstance(strategy, DecomposedStrategy):
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
assert isinstance(strategy, DecomposedStrategy)
return strategy
def step(self, action: float) -> None:
"""Execute one step or SAOE.
Parameters
----------
action (float):
The amount you wish to deal. The simulator does not guarantee that the full amount will be dealt.
"""
assert not self._done, "Simulator has already done!"
try:
self._iter_strategy(action=action)
except StopIteration:
self._done = True
assert self._executor is not None
_, all_indicators = get_portfolio_and_indicator(self._executor)
self._maintainer.update(
inner_executor=self._inner_executor,
inner_strategy=self._inner_strategy,
done=self._done,
all_indicators=all_indicators,
)
def get_state(self) -> SAOEState:
return SAOEState(
order=self._order,
cur_time=self._inner_executor.trade_calendar.get_step_time()[0],
position=self._maintainer.position,
history_exec=self._maintainer.history_exec,
history_steps=self._maintainer.history_steps,
metrics=self._maintainer.metrics,
backtest_data=self._backtest_data,
ticks_per_step=self._ticks_per_step,
ticks_index=self._ticks_index,
ticks_for_order=self._ticks_for_order,
)
def done(self) -> bool:
return self._done
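
A rough sketch of driving the simulator, assuming the order and configs are already built (make_inner_executor is a hypothetical factory of signature (freq, common_infra) -> BaseExecutor):

sim = SingleAssetOrderExecutionQlib(
    order=order,                            # qlib Order: one instrument, one day
    time_per_step="30min",
    qlib_config=qlib_config,                # dict consumed by init_qlib
    inner_executor_fn=make_inner_executor,  # hypothetical factory
    exchange_config=exchange_config,        # an ExchangeConfig instance
)
while not sim.done():
    sim.step(action=100.0)                  # amount the policy wants to trade this step
state = sim.get_state()                     # SAOEState consumed by the interpreters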

View file

@@ -4,18 +4,20 @@
from __future__ import annotations
from pathlib import Path
from typing import NamedTuple, Any, TypeVar, cast
from typing import Any, NamedTuple, Optional, TypeVar, cast
import numpy as np
import pandas as pd
from qlib.backtest.decision import Order, OrderDir
from qlib.constant import EPS
from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_simple_intraday_backtest_data
from qlib.rl.simulator import Simulator
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
from qlib.rl.utils import LogLevel
from qlib.typehint import TypedDict
# TODO: Integrating Qlib's native data with simulator_simple
__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
@@ -33,40 +35,40 @@ class SAOEMetrics(TypedDict):
stock_id: str
"""Stock ID of this record."""
datetime: pd.Timestamp
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
"""Datetime of this record (this is index in the dataframe)."""
direction: int
"""Direction of the order. 0 for sell, 1 for buy."""
# Market information.
market_volume: float
market_volume: np.ndarray | float
"""(total) market volume traded in the period."""
market_price: float
market_price: np.ndarray | float
"""Deal price. If it's a period of time, this is the average market deal price."""
# Strategy records.
amount: float
amount: np.ndarray | float
"""Total amount (volume) strategy intends to trade."""
inner_amount: float
inner_amount: np.ndarray | float
"""Total amount that the lower-level strategy intends to trade
(might be larger than amount, e.g., to ensure ffr)."""
deal_amount: float
deal_amount: np.ndarray | float
"""Amount that successfully takes effect (must be less than inner_amount)."""
trade_price: float
trade_price: np.ndarray | float
"""The average deal price for this strategy."""
trade_value: float
"""Total worth of trading. In the simple simulaton, trade_value = deal_amount * price."""
position: float
trade_value: np.ndarray | float
"""Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
position: np.ndarray | float
"""Position left after this "period"."""
# Accumulated metrics
ffr: float
ffr: np.ndarray | float
"""Completed how much percent of the daily order."""
pa: float
pa: np.ndarray | float
"""Price advantage compared to baseline (i.e., trade with baseline market price).
The baseline is trade price when using TWAP strategy to execute this order.
Please note that there could be data leak here).
@@ -87,7 +89,7 @@ class SAOEState(NamedTuple):
history_steps: pd.DataFrame
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
metrics: SAOEMetrics | None
metrics: Optional[SAOEMetrics]
"""Daily metric, only available when the trading is in "done" state."""
backtest_data: IntradayBacktestData
@@ -114,13 +116,13 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
If such fine granularity is not needed, use ``ticks_per_step`` to
lengthen the ticks for each step.
In each step, the traded amount are "equally" splitted to each tick,
then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``),
In each step, the traded amount are "equally" separated to each tick,
then bounded by volume maximum execution volume (i.e., ``vol_threshold``),
and if it's the last step, try to ensure all the amount to be executed.
Parameters
----------
initial
order
The seed to start an SAOE simulator is an order.
ticks_per_step
How many ticks per step.
@@ -140,7 +142,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
See :class:`SAOEMetrics` for available columns.
Index is ``datetime``, which is the **starting** time of each step."""
metrics: SAOEMetrics | None
metrics: Optional[SAOEMetrics]
"""Metrics. Only available when done."""
twap_price: float
@@ -159,15 +161,21 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
data_dir: Path,
ticks_per_step: int = 30,
deal_price_type: DealPriceType = "close",
vol_threshold: float | None = None,
vol_threshold: Optional[float] = None,
) -> None:
super().__init__(initial=order)
self.order = order
self.ticks_per_step: int = ticks_per_step
self.deal_price_type = deal_price_type
self.vol_threshold = vol_threshold
self.data_dir = data_dir
self.backtest_data = load_intraday_backtest_data(
self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
self.backtest_data = load_simple_intraday_backtest_data(
self.data_dir,
order.stock_id,
pd.Timestamp(order.start_time.date()),
self.deal_price_type,
order.direction,
)
self.ticks_index = self.backtest_data.get_time_index()
@@ -188,9 +196,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
self.metrics = None
self.market_price: np.ndarray | None = None
self.market_vol: np.ndarray | None = None
self.market_vol_limit: np.ndarray | None = None
self.market_price: Optional[np.ndarray] = None
self.market_vol: Optional[np.ndarray] = None
self.market_vol_limit: Optional[np.ndarray] = None
def step(self, amount: float) -> None:
"""Execute one step or SAOE.
@@ -205,7 +213,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
self.market_price = self.market_vol = None # avoid misuse
exec_vol = self._split_exec_vol(amount)
assert self.market_price is not None and self.market_vol is not None
assert self.market_price is not None
assert self.market_vol is not None
ticks_position = self.position - np.cumsum(exec_vol)
@@ -363,7 +372,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
inner_amount=exec_vol.sum(),
deal_amount=exec_vol.sum(),  # in this simulator, there are no other restrictions
trade_price=exec_avg_price,
trade_value=np.sum(market_price * exec_vol),
trade_value=float(np.sum(market_price * exec_vol)),
position=self.position,
ffr=float(exec_vol.sum() / self.order.amount),
pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
@@ -386,7 +395,9 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
def price_advantage(
exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
exec_price: _float_or_ndarray,
baseline_price: float,
direction: OrderDir | int,
) -> _float_or_ndarray:
if baseline_price == 0: # something is wrong with data. Should be nan here
if isinstance(exec_price, float):

View file

@@ -0,0 +1,111 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Any, List, Tuple, cast
import numpy as np
import pandas as pd
from qlib.backtest import CommonInfrastructure, get_exchange
from qlib.backtest.account import Account
from qlib.backtest.decision import OrderDir
from qlib.backtest.executor import BaseExecutor
from qlib.rl.from_neutrader.config import ExchangeConfig
from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray
from qlib.utils.time import Freq
def get_common_infra(
config: ExchangeConfig,
trade_date: pd.Timestamp,
codes: List[str],
cash_limit: float = None,
) -> CommonInfrastructure:
# need to specify a range here for acceleration
if cash_limit is None:
trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition")
else:
trade_account = Account(
init_cash=cash_limit,
benchmark_config={},
pos_type="Position",
position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes},
)
exchange = get_exchange(
codes=codes,
freq="1min",
limit_threshold=config.limit_threshold,
deal_price=config.deal_price,
open_cost=config.open_cost,
close_cost=config.close_cost,
min_cost=config.min_cost if config.trade_unit is not None else 0,
start_time=trade_date,
end_time=trade_date + pd.DateOffset(1),
trade_unit=config.trade_unit,
volume_threshold=config.volume_threshold,
)
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
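
For instance (arguments are illustrative; exchange_config is an ExchangeConfig such as the one defined in from_neutrader/config.py):

infra = get_common_infra(
    config=exchange_config,
    trade_date=pd.Timestamp("2022-08-01"),
    codes=["SH600000"],  # hypothetical instrument
)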
def get_ticks_slice(
ticks_index: pd.DatetimeIndex,
start: pd.Timestamp,
end: pd.Timestamp,
include_end: bool = False,
) -> pd.DatetimeIndex:
if not include_end:
end = end - ONE_SEC
return ticks_index[ticks_index.slice_indexer(start, end)]
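
A small check of the half-open convention (the end tick is dropped unless include_end=True):

idx = pd.date_range("2022-08-01 09:30", "2022-08-01 10:00", freq="1min")
assert get_ticks_slice(idx, idx[0], idx[-1])[-1] == pd.Timestamp("2022-08-01 09:59")
assert get_ticks_slice(idx, idx[0], idx[-1], include_end=True)[-1] == idx[-1]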
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
# dataframe.append is deprecated
other_df = pd.DataFrame(other).set_index("datetime")
other_df.index.name = "datetime"
res = pd.concat([df, other_df], axis=0)
return res
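
Because DataFrame.append is deprecated in recent pandas, rows are appended via pd.concat, e.g.:

df = pd.DataFrame(columns=["datetime", "amount"]).set_index("datetime")
df = dataframe_append(df, [{"datetime": pd.Timestamp("2022-08-01 09:30"), "amount": 100.0}])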
def price_advantage(
exec_price: _float_or_ndarray,
baseline_price: float,
direction: OrderDir | int,
) -> _float_or_ndarray:
if baseline_price == 0: # something is wrong with data. Should be nan here
if isinstance(exec_price, float):
return 0.0
else:
return np.zeros_like(exec_price)
if direction == OrderDir.BUY:
res = (1 - exec_price / baseline_price) * 10000
elif direction == OrderDir.SELL:
res = (exec_price / baseline_price - 1) * 10000
else:
raise ValueError(f"Unexpected order direction: {direction}")
res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)
if res_wo_nan.size == 1:
return res_wo_nan.item()
else:
return cast(_float_or_ndarray, res_wo_nan)
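
A worked example: buying below the baseline yields a positive advantage measured in basis points.

pa = price_advantage(9.9, 10.0, OrderDir.BUY)  # (1 - 9.9 / 10.0) * 10000
assert abs(pa - 100.0) < 1e-6                  # roughly 100 bps of advantage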
def get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]:
all_executors = executor.get_all_executors()
all_portfolio_metrics = {
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
for _executor in all_executors
if _executor.trade_account.is_port_metr_enabled()
}
all_indicators = {}
for _executor in all_executors:
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
return all_portfolio_metrics, all_indicators
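
An illustrative call (the executor is hypothetical; dictionary keys follow each executor's parsed frequency, so a 1-minute executor is assumed to land under "1min"):

portfolio_metrics, indicators = get_portfolio_and_indicator(executor)
step_indicators = indicators["1min"]  # per-step trade indicator DataFrame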

View file

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import Generic, Any, TypeVar, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar
from qlib.typehint import final
@@ -20,7 +20,7 @@ class Reward(Generic[SimulatorState]):
Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
"""
env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
@final
def __call__(self, simulator_state: SimulatorState) -> float:
@@ -30,14 +30,15 @@ class Reward(Generic[SimulatorState]):
"""Implement this method for your own reward."""
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
def log(self, name, value):
def log(self, name: str, value: Any) -> None:
assert self.env is not None
self.env.logger.add_scalar(name, value)
class RewardCombination(Reward):
"""Combination of multiple reward."""
def __init__(self, rewards: dict[str, tuple[Reward, float]]):
def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None:
self.rewards = rewards
def reward(self, simulator_state: Any) -> float:

View file

@@ -3,7 +3,7 @@
from __future__ import annotations
from typing import TypeVar, Generic, Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from .seed import InitialStateType
@@ -49,7 +49,7 @@ class Simulator(Generic[InitialStateType, StateType, ActType]):
Simulators are discouraged from using this, because it is prone to induce errors.
"""
env: EnvWrapper | None = None
env: Optional[EnvWrapper] = None
def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
pass

View file

@@ -3,17 +3,17 @@
from __future__ import annotations
from typing import Callable, Sequence, cast, Any
from typing import Any, Callable, Sequence, cast
from tianshou.policy import BasePolicy
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
from qlib.rl.reward import Reward
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.utils import FiniteEnvType, LogWriter
from .vessel import TrainingVessel
from .trainer import Trainer
from .vessel import TrainingVessel
def train(

View file

@@ -12,7 +12,7 @@ import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import numpy as np
import torch

View file

@@ -6,13 +6,13 @@ from __future__ import annotations
import copy
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any, Iterable, TypeVar, Sequence, cast
from typing import Any, Iterable, Sequence, TypeVar, cast
import torch
from qlib.rl.simulator import InitialStateType
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogCollector, LogWriter, LogBuffer, vectorize_env, LogLevel
from qlib.log import get_module_logger
from qlib.rl.simulator import InitialStateType
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogBuffer, LogCollector, LogLevel, LogWriter, vectorize_env
from qlib.rl.utils.finite_env import FiniteVectorEnv
from qlib.typehint import Literal

View file

@@ -4,7 +4,7 @@
from __future__ import annotations
import weakref
from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
import numpy as np
from tianshou.data import Collector, VectorReplayBuffer
@@ -12,12 +12,11 @@ from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
from qlib.constant import INF
from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
from qlib.rl.reward import Reward
from qlib.rl.utils import DataQueue
from qlib.log import get_module_logger
from qlib.rl.interpreter import ActionInterpreter, ActType, ObsType, PolicyActType, StateInterpreter, StateType
from qlib.rl.reward import Reward
from qlib.rl.simulator import InitialStateType, Simulator
from qlib.rl.utils import DataQueue
from qlib.rl.utils.finite_env import FiniteVectorEnv
if TYPE_CHECKING:
@ -209,6 +208,9 @@ class TrainingVessel(TrainingVesselBase):
order = np.random.permutation(len(collection))
res = [collection[o] for o in order[:size]]
_logger.info(
"Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(res)
"Fast running in development mode. Cut %s initial states from %d to %d.",
name,
len(collection),
len(res),
)
return res


@ -1,7 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .data_queue import *
from .env_wrapper import *
from .finite_env import *
from .log import *
from .data_queue import DataQueue
from .env_wrapper import EnvWrapper, EnvWrapperStatus
from .finite_env import FiniteEnvType, vectorize_env
from .log import ConsoleWriter, CsvWriter, LogBuffer, LogCollector, LogLevel, LogWriter
__all__ = [
"LogLevel",
"DataQueue",
"EnvWrapper",
"FiniteEnvType",
"LogCollector",
"LogWriter",
"vectorize_env",
"ConsoleWriter",
"CsvWriter",
"EnvWrapperStatus",
"LogBuffer",
]


@ -1,13 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from __future__ import annotations
import multiprocessing
import os
import threading
import time
import warnings
from queue import Empty
from typing import TypeVar, Generic, Sequence, cast
from typing import Any, Generator, Generic, Sequence, TypeVar, cast
from qlib.log import get_module_logger
@ -60,7 +62,7 @@ class DataQueue(Generic[T]):
shuffle: bool = True,
producer_num_workers: int = 0,
queue_maxsize: int = 0,
):
) -> None:
if queue_maxsize == 0:
if os.cpu_count() is not None:
queue_maxsize = cast(int, os.cpu_count())
@ -78,14 +80,14 @@ class DataQueue(Generic[T]):
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
self._done = multiprocessing.Value("i", 0)
def __enter__(self):
def __enter__(self) -> DataQueue:
self.activate()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.cleanup()
def cleanup(self):
def cleanup(self) -> None:
with self._done.get_lock():
self._done.value += 1
for repeat in range(500):
@ -105,7 +107,7 @@ class DataQueue(Generic[T]):
break
_logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
def get(self, block=True):
def get(self, block: bool = True) -> Any:
if not hasattr(self, "_first_get"):
self._first_get = True
if self._first_get:
@ -120,17 +122,17 @@ class DataQueue(Generic[T]):
if self._done.value:
raise StopIteration # pylint: disable=raise-missing-from
def put(self, obj, block=True, timeout=None):
return self._queue.put(obj, block=block, timeout=timeout)
def put(self, obj: Any, block: bool = True, timeout: int | None = None) -> None:
self._queue.put(obj, block=block, timeout=timeout)
def mark_as_done(self):
def mark_as_done(self) -> None:
with self._done.get_lock():
self._done.value = 1
def done(self):
def done(self) -> int:
return self._done.value
def activate(self):
def activate(self) -> DataQueue:
if self._activated:
raise ValueError("DataQueue can not activate twice.")
thread = threading.Thread(target=self._producer, daemon=True)
@ -138,20 +140,20 @@ class DataQueue(Generic[T]):
self._activated = True
return self
def __del__(self):
def __del__(self) -> None:
_logger.debug(f"__del__ of {__name__}.DataQueue")
self.cleanup()
def __iter__(self):
def __iter__(self) -> Generator[Any, None, None]:
if not self._activated:
raise ValueError(
"Need to call activate() to launch a daemon worker "
"to produce data into data queue before using it. "
"You probably have forgotten to use the DataQueue in a with block."
"You probably have forgotten to use the DataQueue in a with block.",
)
return self._consumer()
def _consumer(self):
def _consumer(self) -> Generator[Any, None, None]:
while True:
try:
yield self.get()
@ -159,7 +161,7 @@ class DataQueue(Generic[T]):
_logger.debug("Data consumer timed-out from get.")
return
def _producer(self):
def _producer(self) -> None:
# pytorch dataloader is used here only because we need its sampler and multi-processing
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
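The with-block mentioned in the error message above corresponds to usage like this (hypothetical sketch; the dataset and `process` are placeholders):

    dataset = list(range(100))                       # any sequence-like collection

    with DataQueue(dataset, shuffle=True) as queue:  # __enter__ calls activate()
        for item in queue:   # __iter__ yields until the producer marks itself done
            process(item)    # placeholder for downstream work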


@ -4,14 +4,15 @@
from __future__ import annotations
import weakref
from typing import Callable, Any, Iterable, Iterator, Generic, cast
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
import gym
from gym import Space
from qlib.rl.aux_info import AuxiliaryInfoCollector
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter
from qlib.rl.reward import Reward
from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
from qlib.typehint import TypedDict
from .finite_env import generate_nan_observation
@ -28,7 +29,7 @@ class InfoDict(TypedDict):
aux_info: dict
"""Any information depends on auxiliary info collector."""
log: dict[str, Any]
log: Dict[str, Any]
"""Collected by LogCollector."""
@ -42,14 +43,15 @@ class EnvWrapperStatus(TypedDict):
cur_step: int
done: bool
initial_state: Any | None
initial_state: Optional[Any]
obs_history: list
action_history: list
reward_history: list
class EnvWrapper(
gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
gym.Env[ObsType, PolicyActType],
Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
):
"""Qlib-based RL environment, subclassing ``gym.Env``.
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
@ -97,11 +99,11 @@ class EnvWrapper(
simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
state_interpreter: StateInterpreter[StateType, ObsType],
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
seed_iterator: Iterable[InitialStateType] | None,
reward_fn: Reward | None = None,
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
logger: LogCollector | None = None,
):
seed_iterator: Optional[Iterable[InitialStateType]],
reward_fn: Optional[Reward] = None,
aux_info_collector: Optional[AuxiliaryInfoCollector[StateType, Any]] = None,
logger: Optional[LogCollector] = None,
) -> None:
# Assign weak reference to wrapper.
#
# Use weak reference here, because:
@ -135,11 +137,11 @@ class EnvWrapper(
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
@property
def action_space(self):
def action_space(self) -> Space:
return self.action_interpreter.action_space
@property
def observation_space(self):
def observation_space(self) -> Space:
return self.state_interpreter.observation_space
def reset(self, **kwargs: Any) -> ObsType:
@ -191,7 +193,7 @@ class EnvWrapper(
self.seed_iterator = None
return generate_nan_observation(self.observation_space)
def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]:
"""Environment step.
See the code along with comments to get a sequence of things happening here.
@ -245,5 +247,5 @@ class EnvWrapper(
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
return obs, rew, done, info_dict
def render(self):
def render(self, mode: str = "human") -> None:
raise NotImplementedError("Render is not implemented in EnvWrapper.")
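Taken together, the constructor parameters above are wired roughly like this (illustrative sketch; the simulator factory, interpreters and reward are placeholders):

    env = EnvWrapper(
        simulator_fn=lambda initial: MySimulator(initial),  # placeholder simulator factory
        state_interpreter=my_state_interpreter,             # placeholder StateInterpreter
        action_interpreter=my_action_interpreter,           # placeholder ActionInterpreter
        seed_iterator=iter(initial_states),                 # None would disable seeding
        reward_fn=my_reward,                                # placeholder Reward
    )
    obs = env.reset()
    obs, rew, done, info = env.step(policy_act)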


@ -11,11 +11,10 @@ from __future__ import annotations
import copy
import warnings
from contextlib import contextmanager
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Set, Tuple, Type, Union
import gym
import numpy as np
from typing import Any, Set, Callable, Type
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
from qlib.typehint import Literal
@ -32,11 +31,11 @@ __all__ = [
"vectorize_env",
]
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
T = Union[dict, list, tuple, np.ndarray]
def fill_invalid(obj):
def fill_invalid(obj: int | float | bool | T) -> T:
if isinstance(obj, (int, float, bool)):
return fill_invalid(np.array(obj))
if hasattr(obj, "dtype"):
@ -55,11 +54,11 @@ def fill_invalid(obj):
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
def is_invalid(arr):
if hasattr(arr, "dtype"):
def is_invalid(arr: int | float | bool | T) -> bool:
if isinstance(arr, np.ndarray):
if np.issubdtype(arr.dtype, np.floating):
return np.isnan(arr).all()
return (np.iinfo(arr.dtype).max == arr).all()
return cast(bool, cast(np.ndarray, np.iinfo(arr.dtype).max == arr).all())
if isinstance(arr, dict):
return all(is_invalid(o) for o in arr.values())
if isinstance(arr, (list, tuple)):
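These two helpers implement the NaN-observation convention; their intended behaviour can be illustrated as follows (sketch, assuming simple float/int array leaves):

    import numpy as np

    obs = {"feature": np.zeros(3, dtype=np.float32), "step": np.array([0])}
    marker = fill_invalid(obs)   # floats become NaN, ints become iinfo(...).max
    assert is_invalid(marker)    # every leaf carries the invalid filler
    assert not is_invalid(obs)   # a normal observation is left untouched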
@ -140,44 +139,44 @@ class FiniteVectorEnv(BaseVectorEnv):
self._collector_guarded: bool = False
def _reset_alive_envs(self):
def _reset_alive_envs(self) -> None:
if not self._alive_env_ids:
# starting or running out
self._alive_env_ids = set(range(self.env_num))
# to workaround with tianshou's buffer and batch
def _set_default_obs(self, obs):
def _set_default_obs(self, obs: Any) -> None:
if obs is not None and self._default_obs is None:
self._default_obs = copy.deepcopy(obs)
def _set_default_info(self, info):
def _set_default_info(self, info: Any) -> None:
if info is not None and self._default_info is None:
self._default_info = copy.deepcopy(info)
def _set_default_rew(self, rew):
def _set_default_rew(self, rew: Any) -> None:
if rew is not None and self._default_rew is None:
self._default_rew = copy.deepcopy(rew)
def _get_default_obs(self):
def _get_default_obs(self) -> Any:
return copy.deepcopy(self._default_obs)
def _get_default_info(self):
def _get_default_info(self) -> Any:
return copy.deepcopy(self._default_info)
def _get_default_rew(self):
def _get_default_rew(self) -> Any:
return copy.deepcopy(self._default_rew)
# END
@staticmethod
def _postproc_env_obs(obs):
def _postproc_env_obs(obs: Any) -> Optional[Any]:
# reserved for shmem vector env to restore empty observation
if obs is None or check_nan_observation(obs):
return None
return obs
@contextmanager
def collector_guard(self):
def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]:
"""Guard the collector. Recommended to guard every collect.
This guard is for two purposes.
@ -207,7 +206,10 @@ class FiniteVectorEnv(BaseVectorEnv):
for logger in self._logger:
logger.on_env_all_done()
def reset(self, id=None):
def reset(
self,
id: int | List[int] | np.ndarray | None = None,
) -> np.ndarray:
assert not self._zombie
# Check whether it's guarded by collector_guard()
@ -219,23 +221,23 @@ class FiniteVectorEnv(BaseVectorEnv):
RuntimeWarning,
)
id = self._wrap_id(id)
wrapped_id = self._wrap_id(id)
self._reset_alive_envs()
# ask super to reset alive envs and remap to current index
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
obs = [None] * len(id)
id2idx = {i: k for k, i in enumerate(id)}
request_id = [i for i in wrapped_id if i in self._alive_env_ids]
obs = [None] * len(wrapped_id)
id2idx = {i: k for k, i in enumerate(wrapped_id)}
if request_id:
for i, o in zip(request_id, super().reset(request_id)):
obs[id2idx[i]] = self._postproc_env_obs(o)
for i, o in zip(id, obs):
for i, o in zip(wrapped_id, obs):
if o is None and i in self._alive_env_ids:
self._alive_env_ids.remove(i)
# logging
for i, o in zip(id, obs):
for i, o in zip(wrapped_id, obs):
if i in self._alive_env_ids:
for logger in self._logger:
logger.on_env_reset(i, obs)
@ -248,19 +250,23 @@ class FiniteVectorEnv(BaseVectorEnv):
obs[i] = self._get_default_obs()
if not self._alive_env_ids:
# comment this line so that the env becomes indisposable
# comment this line so that the env becomes indispensable
# self.reset()
self._zombie = True
raise StopIteration
return np.stack(obs)
def step(self, action, id=None):
def step(
self,
action: np.ndarray,
id: int | List[int] | np.ndarray | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert not self._zombie
id = self._wrap_id(id)
id2idx = {i: k for k, i in enumerate(id)}
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
result = [[None, None, False, None] for _ in range(len(id))]
wrapped_id = self._wrap_id(id)
id2idx = {i: k for k, i in enumerate(wrapped_id)}
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
result = [[None, None, False, None] for _ in range(len(wrapped_id))]
# ask super to step alive envs and remap to current index
if request_id:
@ -270,7 +276,7 @@ class FiniteVectorEnv(BaseVectorEnv):
result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])
# logging
for i, r in zip(id, result):
for i, r in zip(wrapped_id, result):
if i in self._alive_env_ids:
for logger in self._logger:
logger.on_env_step(i, *r)
@ -287,7 +293,8 @@ class FiniteVectorEnv(BaseVectorEnv):
if r[3] is None:
result[i][3] = self._get_default_info()
return list(map(np.stack, zip(*result)))
ret = list(map(np.stack, zip(*result)))
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
@ -306,7 +313,7 @@ def vectorize_env(
env_factory: Callable[..., gym.Env],
env_type: FiniteEnvType,
concurrency: int,
logger: LogWriter | list[LogWriter],
logger: LogWriter | List[LogWriter],
) -> FiniteVectorEnv:
"""Helper function to create a vector env. Can be used to replace usual VectorEnv.
@ -350,7 +357,7 @@ def vectorize_env(
def env_factory(): ...
vectorize_env(env_factory, ...)
"""
env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = {
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
"dummy": FiniteDummyVectorEnv,
"subproc": FiniteSubprocVectorEnv,
"shmem": FiniteShmemVectorEnv,


@ -21,7 +21,7 @@ import logging
from collections import defaultdict
from enum import IntEnum
from pathlib import Path
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, List, Sequence, Set, Tuple, TypeVar
import numpy as np
import pandas as pd
@ -65,13 +65,13 @@ class LogCollector:
``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipes.
"""
_logged: dict[str, tuple[int, Any]]
_logged: Dict[str, Tuple[int, Any]]
_min_loglevel: int
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
self._min_loglevel = int(min_loglevel)
def reset(self):
def reset(self) -> None:
"""Clear all collected contents."""
self._logged = {}
@ -104,7 +104,10 @@ class LogCollector:
self._add_metric(name, scalar, loglevel)
def add_array(
self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
self,
name: str,
array: np.ndarray | pd.DataFrame | pd.Series,
loglevel: int | LogLevel = LogLevel.PERIODIC,
) -> None:
"""Add an array with name into logging."""
if loglevel < self._min_loglevel:
@ -127,7 +130,7 @@ class LogCollector:
self._add_metric(name, obj, loglevel)
def logs(self) -> dict[str, np.ndarray]:
def logs(self) -> Dict[str, np.ndarray]:
return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
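One collection step, as the tests later in this change exercise it (sketch):

    collector = LogCollector(min_loglevel=LogLevel.PERIODIC)
    collector.reset()                     # once per environment step
    collector.add_scalar("reward", 42.0)  # entries below min_loglevel are dropped early
    collector.add_array("history", pd.Series([1.0, 2.0]))
    info = InfoDict(log=collector.logs(), aux_info={})  # as assembled in EnvWrapper.step()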
@ -154,16 +157,16 @@ class LogWriter(Generic[ObsType, ActType]):
active_env_ids: Set[int]
"""Active environment ids in vector env."""
episode_lengths: dict[int, int]
episode_lengths: Dict[int, int]
"""Map from environment id to episode length."""
episode_rewards: dict[int, list[float]]
episode_rewards: Dict[int, List[float]]
"""Map from environment id to episode total reward."""
episode_logs: dict[int, list]
episode_logs: Dict[int, list]
"""Map from environment id to episode logs."""
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
self.loglevel = loglevel
self.global_step = 0
@ -207,11 +210,12 @@ class LogWriter(Generic[ObsType, ActType]):
# These are runtime infos.
# Though they are loaded, I don't think it really helps.
self.active_env_ids = state_dict["active_env_ids"]
self.episode_lenghts = state_dict["episode_lengths"]
self.episode_lengths = state_dict["episode_lengths"]
self.episode_rewards = state_dict["episode_rewards"]
self.episode_logs = state_dict["episode_logs"]
def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any:
@staticmethod
def aggregation(array: Sequence[Any], name: str | None = None) -> Any:
"""Aggregation function from step-wise to episode-wise.
If it's a sequence of float, take the mean.
@ -229,7 +233,7 @@ class LogWriter(Generic[ObsType, ActType]):
else:
return array[0]
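Concretely, given the docstring and the fallback branch above (and since `aggregation` is now a staticmethod):

    LogWriter.aggregation([1.0, 2.0, 3.0])  # -> 2.0, the mean of a float sequence
    LogWriter.aggregation(["buy", "sell"])  # -> "buy", first element for other types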
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
"""This is triggered at the end of each trajectory.
Parameters
@ -242,7 +246,7 @@ class LogWriter(Generic[ObsType, ActType]):
Logged contents for every step.
"""
def log_step(self, reward: float, contents: dict[str, Any]) -> None:
def log_step(self, reward: float, contents: Dict[str, Any]) -> None:
"""This is triggered at each step.
Parameters
@ -265,7 +269,7 @@ class LogWriter(Generic[ObsType, ActType]):
# TODO: reward can be a list of list for MARL
self.episode_rewards[env_id].append(rew)
values: dict[str, Any] = {}
values: Dict[str, Any] = {}
for key, (loglevel, value) in info["log"].items():
if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME)
@ -393,11 +397,11 @@ class ConsoleWriter(LogWriter):
def __init__(
self,
log_every_n_episode: int = 20,
total_episodes: int | None = None,
total_episodes: int | None = None,
float_format: str = ":.4f",
counter_format: str = ":4d",
loglevel: int | LogLevel = LogLevel.PERIODIC,
):
) -> None:
super().__init__(loglevel)
# TODO: support log_every_n_step
self.log_every_n_episode = log_every_n_episode
@ -412,15 +416,15 @@ class ConsoleWriter(LogWriter):
# FIXME: save & reload
def clear(self):
def clear(self) -> None:
super().clear()
# Clear average meters
self.metric_counts: dict[str, int] = defaultdict(int)
self.metric_sums: dict[str, float] = defaultdict(float)
self.metric_counts: Dict[str, int] = defaultdict(int)
self.metric_sums: Dict[str, float] = defaultdict(float)
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
# Aggregate step-wise to episode-wise
episode_wise_contents: dict[str, list] = defaultdict(list)
episode_wise_contents: Dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
@ -429,7 +433,7 @@ class ConsoleWriter(LogWriter):
# Generate log contents and track them in average-meter.
# This should be done at every step, regardless of periodic or not.
logs: dict[str, float] = {}
logs: Dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values, name) # type: ignore
@ -441,7 +445,7 @@ class ConsoleWriter(LogWriter):
# Only log periodically or at the end
self.console_logger.info(self.generate_log_message(logs))
def generate_log_message(self, logs: dict[str, float]) -> str:
def generate_log_message(self, logs: Dict[str, float]) -> str:
if self.prefix:
msg_prefix = self.prefix + " "
else:
@ -471,29 +475,29 @@ class CsvWriter(LogWriter):
SUPPORTED_TYPES = (float, str, pd.Timestamp)
all_records: list[dict[str, Any]]
all_records: List[Dict[str, Any]]
# FIXME: save & reload
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
super().__init__(loglevel)
self.output_dir = output_dir
self.output_dir.mkdir(exist_ok=True)
def clear(self):
def clear(self) -> None:
super().clear()
self.all_records = []
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
# FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
episode_wise_contents: dict[str, list] = defaultdict(list)
episode_wise_contents: Dict[str, list] = defaultdict(list)
for step_contents in contents:
for name, value in step_contents.items():
if isinstance(value, self.SUPPORTED_TYPES):
episode_wise_contents[name].append(value)
logs: dict[str, float] = {}
logs: Dict[str, float] = {}
for name, values in episode_wise_contents.items():
logs[name] = self.aggregation(values, name) # type: ignore


@ -2,14 +2,14 @@
# Licensed under the MIT License.
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Generator, Optional
from abc import ABCMeta, abstractmethod
from typing import Any, Generator, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from qlib.backtest.exchange import Exchange
from qlib.backtest.position import BasePosition
from typing import Tuple, Union
from typing import Tuple
from ..backtest.decision import BaseTradeDecision
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
@ -207,8 +207,18 @@ class BaseStrategy:
range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
def post_exe_step(self, execute_result: list) -> None:
"""
A hook for doing something after the corresponding executor finished its execution.
Parameters
----------
execute_result :
the execution result
"""
class RLStrategy(BaseStrategy):
class RLStrategy(BaseStrategy, metaclass=ABCMeta):
"""RL-based strategy"""
def __init__(
@ -229,14 +239,14 @@ class RLStrategy(BaseStrategy):
self.policy = policy
class RLIntStrategy(RLStrategy):
class RLIntStrategy(RLStrategy, metaclass=ABCMeta):
"""(RL)-based (Strategy) with (Int)erpreter"""
def __init__(
self,
policy,
state_interpreter: Union[dict, StateInterpreter],
action_interpreter: Union[dict, ActionInterpreter],
state_interpreter: dict | StateInterpreter,
action_interpreter: dict | ActionInterpreter,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
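Elsewhere in this file, the new `post_exe_step` hook gives strategies a place to react to execution results; a subclass might use it like this (hypothetical):

    class FillTrackingStrategy(BaseStrategy):
        """Toy strategy that records what the executor filled in each step."""

        def post_exe_step(self, execute_result: list) -> None:
            # invoked after the corresponding executor finishes one step
            self.last_execute_result = execute_result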


@ -271,7 +271,7 @@ class LocIndexer:
if isinstance(_indexing, IndexData):
_indexing = _indexing.data
assert _indexing.ndim == 1
if _indexing.dtype != np.bool:
if _indexing.dtype != bool:
_indexing = np.array(list(index.index(i) for i in _indexing))
else:
_indexing = index.index(_indexing)
@ -431,7 +431,7 @@ class IndexData(metaclass=index_data_ops_creator):
# The code below could be simpler like methods in __getattribute__
def __invert__(self):
return self.__class__(~self.data.astype(np.bool), *self.indices)
return self.__class__(~self.data.astype(bool), *self.indices)
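(`np.bool` was deprecated in NumPy 1.20 and removed in 1.24; the builtin `bool` is the drop-in replacement, which is what the change above applies.)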
def abs(self):
"""get the abs of data except np.NaN."""


@ -5,6 +5,8 @@ from random import randint, choice
from pathlib import Path
import re
from typing import Any, Tuple
import gym
import numpy as np
import pandas as pd
@ -24,16 +26,16 @@ from qlib.rl.utils.finite_env import vectorize_env
class SimpleEnv(gym.Env[int, int]):
def __init__(self):
def __init__(self) -> None:
self.logger = LogCollector()
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
def reset(self, *args: Any, **kwargs: Any) -> int:
self.step_count = 0
return 0
def step(self, action: int):
def step(self, action: int) -> Tuple[int, float, bool, dict]:
self.logger.reset()
self.logger.add_scalar("reward", 42.0)
@ -53,6 +55,9 @@ class SimpleEnv(gym.Env[int, int]):
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
def render(self, mode: str = "human") -> None:
pass
class AnyPolicy(BasePolicy):
def forward(self, batch, state=None):
@ -86,7 +91,8 @@ def test_simple_env_logger(caplog):
class SimpleSimulator(Simulator[int, float, float]):
def __init__(self, initial: int, **kwargs) -> None:
def __init__(self, initial: int, **kwargs: Any) -> None:
super(SimpleSimulator, self).__init__(initial, **kwargs)
self.initial = float(initial)
def step(self, action: float) -> None:


@ -0,0 +1,177 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import pandas as pd
import pytest
from qlib.backtest.decision import Order, OrderDir
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
from qlib.backtest.utils import CommonInfrastructure
from qlib.contrib.strategy import TWAPStrategy
from qlib.rl.order_execution import CategoricalActionInterpreter
from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, SingleAssetOrderExecutionQlib
TOTAL_POSITION = 2100.0
python_version_request = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
def is_close(a: float, b: float, epsilon: float = 1e-4) -> bool:
return abs(a - b) <= epsilon
def get_order() -> Order:
return Order(
stock_id="SH600000",
amount=TOTAL_POSITION,
direction=OrderDir.BUY,
start_time=pd.Timestamp("2019-03-04 09:30:00"),
end_time=pd.Timestamp("2019-03-04 14:29:00"),
)
def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib:
def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor:
return NestedExecutor(
time_per_step=time_per_step,
inner_strategy=TWAPStrategy(),
inner_executor=SimulatorExecutor(
time_per_step="1min",
verbose=False,
trade_type=SimulatorExecutor.TT_SERIAL,
generate_report=False,
common_infra=common_infra,
track_data=True,
),
common_infra=common_infra,
track_data=True,
)
DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "qlib_simulator"
# fmt: off
qlib_config = {
"provider_uri_day": DATA_ROOT_DIR / "qlib_1d",
"provider_uri_1min": DATA_ROOT_DIR / "qlib_1min",
"feature_root_dir": DATA_ROOT_DIR / "qlib_handler_stock",
"feature_columns_today": [
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5",
],
"feature_columns_yesterday": [
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
],
}
# fmt: on
exchange_config = ExchangeConfig(
limit_threshold=("$ask == 0", "$bid == 0"),
deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
volume_threshold={
"all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"),
"buy": ("current", "$askV1"),
"sell": ("current", "$bidV1"),
},
open_cost=0.0005,
close_cost=0.0015,
min_cost=5.0,
trade_unit=None,
cash_limit=None,
generate_report=False,
)
return SingleAssetOrderExecutionQlib(
order=order,
time_per_step="30min",
qlib_config=qlib_config,
inner_executor_fn=_inner_executor_fn,
exchange_config=exchange_config,
)
@python_version_request
def test_simulator_first_step():
order = get_order()
simulator = get_simulator(order)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2019-03-04 09:30:00")
assert state.position == TOTAL_POSITION
AMOUNT = 300.0
simulator.step(AMOUNT)
state = simulator.get_state()
assert state.cur_time == pd.Timestamp("2019-03-04 10:00:00")
assert state.position == TOTAL_POSITION - AMOUNT
assert len(state.history_exec) == 30
assert state.history_exec.index[0] == pd.Timestamp("2019-03-04 09:30:00")
assert is_close(state.history_exec["market_volume"].iloc[0], 109382.382812)
assert is_close(state.history_exec["market_price"].iloc[0], 149.566483)
assert (state.history_exec["amount"] == AMOUNT / 30).all()
assert (state.history_exec["deal_amount"] == AMOUNT / 30).all()
assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483)
assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825)
assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30)
# assert state.history_exec["ffr"].iloc[0] == 1 / 60 # FIXME
assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938)
assert state.history_steps["amount"].iloc[0] == AMOUNT
assert state.history_steps["deal_amount"].iloc[0] == AMOUNT
assert state.history_steps["ffr"].iloc[0] == 1.0
assert is_close(
state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0),
(state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000,
)
@python_version_request
def test_simulator_stop_twap() -> None:
order = get_order()
simulator = get_simulator(order)
NUM_STEPS = 7
for i in range(NUM_STEPS):
simulator.step(TOTAL_POSITION / NUM_STEPS)
HISTORY_STEP_LENGTH = 30 * NUM_STEPS
state = simulator.get_state()
assert len(state.history_exec) == HISTORY_STEP_LENGTH
assert (state.history_exec["deal_amount"] == TOTAL_POSITION / HISTORY_STEP_LENGTH).all()
assert is_close(state.history_steps["position"].iloc[0], TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS)
assert is_close(state.history_steps["position"].iloc[-1], 0.0)
assert is_close(state.position, 0.0)
assert is_close(state.metrics["ffr"], 1.0)
assert is_close(state.metrics["market_price"], state.backtest_data.get_deal_price().mean())
assert is_close(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
assert is_close(state.metrics["trade_price"], state.metrics["market_price"])
assert is_close(state.metrics["pa"], 0.0)
assert simulator.done()
@python_version_request
def test_interpreter() -> None:
NUM_EXECUTION = 3
order = get_order()
simulator = get_simulator(order)
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
NUM_STEPS = 7
state = simulator.get_state()
position_history = []
for i in range(NUM_STEPS):
simulator.step(interpreter_action(state, 1))
state = simulator.get_state()
position_history.append(state.position)
assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
if __name__ == "__main__":
test_simulator_first_step()
test_simulator_stop_twap()
test_interpreter()


@ -9,7 +9,6 @@ from typing import NamedTuple
import numpy as np
import pandas as pd
import pytest
import torch
from tianshou.data import Batch
@ -17,8 +16,8 @@ from qlib.backtest import Order
from qlib.config import C
from qlib.log import set_log_with_config
from qlib.rl.data import pickle_styled
from qlib.rl.trainer import backtest, train
from qlib.rl.order_execution import *
from qlib.rl.trainer import backtest, train
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
@ -38,7 +37,7 @@ CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
def test_pickle_data_inspect():
data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
assert len(data) == 390
data = pickle_styled.load_intraday_processed_data(