зеркало из https://github.com/microsoft/qlib.git
Migrate NeuTrader to Qlib RL (#1169)
* Refine previous version RL codes
* Polish utils/__init__.py
* Draft
* Use | instead of Union
* Simulator & action interpreter
* Test passed
* Migrate to SAOEState & new qlib interpreter
* Black format
* . Revert file_storage change
* Refactor file structure & renaming functions
* Enrich test cases
* Add QlibIntradayBacktestData
* Test interpreter
* Black format
* .
.
.
* Rename receive_execute_result()
* Use indicator to simplify state update
* Format code
* Modify data path
* Adjust file structure
* Minor change
* Add copyright message
* Format code
* Rename util functions
* Add CI
* Pylint issue
* Remove useless code to pass pylint
* Pass mypy
* Mypy issue
* mypy issue
* mypy issue
* Revert "mypy issue"
This reverts commit 8eb1b0174e
.
* mypy issue
* mypy issue
* Fix the numpy version incompatible bug
* Fix a minor typing issue
* Try to skip python 3.7 test for qlib simulator
* Resolve PR comments by Yuge; solve several CI issues.
* Black issue
* Fix a low-level type error
* Change data name
* Resolve PR comments. Leave TODOs in the code base.
Co-authored-by: Young <afe.young@gmail.com>
This commit is contained in:
Родитель
687edd79d0
Коммит
2752bdc92c
|
@ -42,7 +42,7 @@ def get_exchange(
|
|||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
|
@ -70,10 +70,10 @@ def get_exchange(
|
|||
min_cost : float
|
||||
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
|
||||
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
deal_price: Union[str, Tuple[str, str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
- (<buy_price>, <sell_price>): Tuple[str] or List[str]
|
||||
- (<buy_price>, <sell_price>): Tuple[str, str] or List[str]
|
||||
|
||||
<deal_price>, <buy_price> or <sell_price> := <price>
|
||||
<price> := str
|
||||
|
|
|
@ -4,10 +4,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import time
|
||||
from enum import IntEnum
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import Generic, List, TYPE_CHECKING, Any, ClassVar, Optional, Tuple, TypeVar, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, cast
|
||||
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from qlib.data.data import Cal
|
||||
|
@ -23,7 +24,6 @@ from dataclasses import dataclass
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
DecisionType = TypeVar("DecisionType")
|
||||
|
||||
|
||||
|
@ -182,8 +182,8 @@ class OrderHelper:
|
|||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
|
||||
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
|
||||
start_time=None if start_time is None else pd.Timestamp(start_time),
|
||||
end_time=None if end_time is None else pd.Timestamp(end_time),
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
|
@ -249,7 +249,7 @@ class IdxTradeRange(TradeRange):
|
|||
class TradeRangeByTime(TradeRange):
|
||||
"""This is a helper function for make decisions"""
|
||||
|
||||
def __init__(self, start_time: str, end_time: str) -> None:
|
||||
def __init__(self, start_time: str | time, end_time: str | time) -> None:
|
||||
"""
|
||||
This is a callable class.
|
||||
|
||||
|
@ -259,13 +259,13 @@ class TradeRangeByTime(TradeRange):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
start_time : str
|
||||
start_time : str | time
|
||||
e.g. "9:30"
|
||||
end_time : str
|
||||
end_time : str | time
|
||||
e.g. "14:30"
|
||||
"""
|
||||
self.start_time = pd.Timestamp(start_time).time()
|
||||
self.end_time = pd.Timestamp(end_time).time()
|
||||
self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time
|
||||
self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time
|
||||
assert self.start_time < self.end_time
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
|
@ -535,7 +535,12 @@ class TradeDecisionWO(BaseTradeDecision[Order]):
|
|||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[object], strategy: BaseStrategy, trade_range: Tuple[int, int] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
order_list: List[Order],
|
||||
strategy: BaseStrategy,
|
||||
trade_range: Union[Tuple[int, int], TradeRange] = None,
|
||||
) -> None:
|
||||
super().__init__(strategy, trade_range=trade_range)
|
||||
self.order_list = cast(List[Order], order_list)
|
||||
start, end = strategy.trade_calendar.get_step_time()
|
||||
|
|
|
@ -32,7 +32,7 @@ class Exchange:
|
|||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
deal_price: Union[str, Tuple[str, str], List[str]] = None,
|
||||
subscribe_fields: list = [],
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
volume_threshold: Union[tuple, dict] = None,
|
||||
|
@ -448,9 +448,9 @@ class Exchange:
|
|||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: Optional[str] = "sum",
|
||||
) -> float:
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return cast(float, self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method))
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
|
||||
def get_deal_price(
|
||||
self,
|
||||
|
@ -459,7 +459,7 @@ class Exchange:
|
|||
end_time: pd.Timestamp,
|
||||
direction: OrderDir,
|
||||
method: Optional[str] = "ts_data_last",
|
||||
) -> float:
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
elif direction == OrderDir.BUY:
|
||||
|
@ -472,7 +472,7 @@ class Exchange:
|
|||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||
self.logger.warning(f"setting deal_price to close price")
|
||||
deal_price = self.get_close(stock_id, start_time, end_time, method)
|
||||
return cast(float, deal_price)
|
||||
return deal_price
|
||||
|
||||
def get_factor(
|
||||
self,
|
||||
|
@ -832,8 +832,11 @@ class Exchange:
|
|||
:param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float}
|
||||
:return: trade_price, trade_val, trade_cost
|
||||
"""
|
||||
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
|
||||
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
|
||||
trade_price = cast(
|
||||
float,
|
||||
self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction),
|
||||
)
|
||||
total_trade_val = cast(float, self.get_volume(order.stock_id, order.start_time, order.end_time)) * trade_price
|
||||
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
|
||||
order.deal_amount = order.amount # set to full amount and clip it step by step
|
||||
# Clipping amount first
|
||||
|
|
|
@ -484,6 +484,7 @@ class NestedExecutor(BaseExecutor):
|
|||
inner_exe_res :
|
||||
the execution result of inner task
|
||||
"""
|
||||
self.inner_strategy.post_exe_step(inner_exe_res)
|
||||
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
|
|
|
@ -102,11 +102,22 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage):
|
|||
self._freq_file_cache = freq
|
||||
return self._freq_file_cache
|
||||
|
||||
def _read_calendar(self, skip_rows: int = 0, n_rows: int = None) -> List[CalVT]:
|
||||
def _read_calendar(self) -> List[CalVT]:
|
||||
# NOTE:
|
||||
# if we want to accelerate partial reading calendar
|
||||
# we can add parameters like `skip_rows: int = 0, n_rows: int = None` to the interface.
|
||||
# Currently, it is not supported for the txt-based calendar
|
||||
|
||||
if not self.uri.exists():
|
||||
self._write_calendar(values=[])
|
||||
with self.uri.open("rb") as fp:
|
||||
return [str(x) for x in np.loadtxt(fp, str, skiprows=skip_rows, max_rows=n_rows, encoding="utf-8")]
|
||||
|
||||
with self.uri.open("r") as fp:
|
||||
res = []
|
||||
for line in fp.readlines():
|
||||
line = line.strip()
|
||||
if len(line) > 0:
|
||||
res.append(line)
|
||||
return res
|
||||
|
||||
def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"):
|
||||
with self.uri.open(mode=mode) as fp:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, TYPE_CHECKING, TypeVar
|
||||
from typing import Optional, TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
|
@ -21,7 +21,7 @@ AuxInfoType = TypeVar("AuxInfoType")
|
|||
class AuxiliaryInfoCollector(Generic[StateType, AuxInfoType]):
|
||||
"""Override this class to collect customized auxiliary information from environment."""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: StateType) -> AuxInfoType:
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import Exchange, Order
|
||||
from .pickle_styled import IntradayBacktestData
|
||||
|
||||
|
||||
class QlibIntradayBacktestData(IntradayBacktestData):
|
||||
"""Backtest data for Qlib simulator"""
|
||||
|
||||
def __init__(self, order: Order, exchange: Exchange, start_time: pd.Timestamp, end_time: pd.Timestamp) -> None:
|
||||
super(QlibIntradayBacktestData, self).__init__()
|
||||
self._order = order
|
||||
self._exchange = exchange
|
||||
self._start_time = start_time
|
||||
self._end_time = end_time
|
||||
|
||||
self._deal_price = cast(
|
||||
pd.Series,
|
||||
self._exchange.get_deal_price(
|
||||
self._order.stock_id,
|
||||
self._start_time,
|
||||
self._end_time,
|
||||
direction=self._order.direction,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
self._volume = cast(
|
||||
pd.Series,
|
||||
self._exchange.get_volume(
|
||||
self._order.stock_id,
|
||||
self._start_time,
|
||||
self._end_time,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Order: {self._order}, Exchange: {self._exchange}, "
|
||||
f"Start time: {self._start_time}, End time: {self._end_time}"
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._deal_price)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
return self._deal_price
|
||||
|
||||
def get_volume(self) -> pd.Series:
|
||||
return self._volume
|
||||
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
|
|
@ -19,19 +19,19 @@ This file shows resemblence to qlib.backtest.high_performance_ds. We might merge
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import List, Sequence, cast
|
||||
from pathlib import Path
|
||||
from typing import List, Sequence, cast
|
||||
|
||||
import cachetools
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from cachetools.keys import hashkey
|
||||
|
||||
from qlib.backtest.decision import OrderDir, Order
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
||||
DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
||||
"""Several ad-hoc deal price.
|
||||
``bid_or_ask``: If sell, use column ``$bid0``; if buy, use column ``$ask0``.
|
||||
|
@ -40,7 +40,7 @@ DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
|
|||
"""
|
||||
|
||||
|
||||
def _infer_processed_data_column_names(shape: int) -> list[str]:
|
||||
def _infer_processed_data_column_names(shape: int) -> List[str]:
|
||||
if shape == 16:
|
||||
return [
|
||||
"$open",
|
||||
|
@ -87,7 +87,36 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
|
|||
|
||||
|
||||
class IntradayBacktestData:
|
||||
"""Raw market data that is often used in backtesting (thus called BacktestData)."""
|
||||
"""
|
||||
Raw market data that is often used in backtesting (thus called BacktestData).
|
||||
|
||||
Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
|
||||
data type.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_volume(self) -> pd.Series:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_time_index(self) -> pd.DatetimeIndex:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SimpleIntradayBacktestData(IntradayBacktestData):
|
||||
"""Backtest data for simple simulator"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -95,8 +124,10 @@ class IntradayBacktestData:
|
|||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int | None = None,
|
||||
):
|
||||
order_dir: int = None,
|
||||
) -> None:
|
||||
super(SimpleIntradayBacktestData, self).__init__()
|
||||
|
||||
backtest = _read_pickle(data_dir / stock_id)
|
||||
backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]]
|
||||
|
||||
|
@ -105,13 +136,13 @@ class IntradayBacktestData:
|
|||
|
||||
self.data: pd.DataFrame = backtest
|
||||
self.deal_price_type: DealPriceType = deal_price
|
||||
self.order_dir: int | None = order_dir
|
||||
self.order_dir = order_dir
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.data})"
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get_deal_price(self) -> pd.Series:
|
||||
|
@ -162,7 +193,14 @@ class IntradayProcessedData:
|
|||
"""Processed data for "yesterday".
|
||||
Number of records must be ``time_length``, and columns must be ``feature_dim``."""
|
||||
|
||||
def __init__(self, data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index):
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> None:
|
||||
proc = _read_pickle(data_dir / stock_id)
|
||||
# We have to infer the names here because,
|
||||
# unfortunately they are not included in the original data.
|
||||
|
@ -190,16 +228,20 @@ class IntradayProcessedData:
|
|||
assert len(self.today.columns) == len(self.yesterday.columns) == feature_dim
|
||||
assert len(self.today) == len(self.yesterday) == time_length
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
|
||||
return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
|
||||
|
||||
|
||||
@lru_cache(maxsize=100) # 100 * 50K = 5MB
|
||||
def load_intraday_backtest_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, deal_price: DealPriceType = "close", order_dir: int | None = None
|
||||
) -> IntradayBacktestData:
|
||||
return IntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
def load_simple_intraday_backtest_data(
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
deal_price: DealPriceType = "close",
|
||||
order_dir: int = None,
|
||||
) -> SimpleIntradayBacktestData:
|
||||
return SimpleIntradayBacktestData(data_dir, stock_id, date, deal_price, order_dir)
|
||||
|
||||
|
||||
@cachetools.cached( # type: ignore
|
||||
|
@ -207,13 +249,19 @@ def load_intraday_backtest_data(
|
|||
key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
|
||||
)
|
||||
def load_intraday_processed_data(
|
||||
data_dir: Path, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
|
||||
data_dir: Path,
|
||||
stock_id: str,
|
||||
date: pd.Timestamp,
|
||||
feature_dim: int,
|
||||
time_index: pd.Index,
|
||||
) -> IntradayProcessedData:
|
||||
return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
|
||||
|
||||
|
||||
def load_orders(
|
||||
order_path: Path, start_time: pd.Timestamp | None = None, end_time: pd.Timestamp | None = None
|
||||
order_path: Path,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Sequence[Order]:
|
||||
"""Load orders, and set start time and end time for the orders."""
|
||||
|
||||
|
@ -251,7 +299,7 @@ def load_orders(
|
|||
OrderDir(int(row["order_type"])),
|
||||
row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second),
|
||||
row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return orders
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# TODO: find a better way to organize contents under this module.
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
||||
# TODO: In the future we should merge the dataclass-based config with Qlib's dict-based config.
|
||||
@dataclass
|
||||
class ExchangeConfig:
|
||||
limit_threshold: Union[float, Tuple[str, str]]
|
||||
deal_price: Union[str, Tuple[str, str]]
|
||||
volume_threshold: dict
|
||||
open_cost: float = 0.0005
|
||||
close_cost: float = 0.0015
|
||||
min_cost: float = 5.0
|
||||
trade_unit: Optional[float] = 100.0
|
||||
cash_limit: Optional[Union[Path, float]] = None
|
||||
generate_report: bool = False
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import collections
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, pool_size: int = 200):
|
||||
self.pool_size = pool_size
|
||||
self.contents: dict = {}
|
||||
self.keys: collections.deque = collections.deque()
|
||||
|
||||
def put(self, key, item):
|
||||
if self.has(key):
|
||||
self.keys.remove(key)
|
||||
self.keys.append(key)
|
||||
self.contents[key] = item
|
||||
while len(self.contents) > self.pool_size:
|
||||
self.contents.pop(self.keys.popleft())
|
||||
|
||||
def get(self, key):
|
||||
return self.contents[key]
|
||||
|
||||
def has(self, key):
|
||||
return key in self.contents
|
||||
|
||||
|
||||
class DataWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
feature_dataset: DatasetH,
|
||||
backtest_dataset: DatasetH,
|
||||
columns_today: List[str],
|
||||
columns_yesterday: List[str],
|
||||
_internal: bool = False,
|
||||
):
|
||||
assert _internal, "Init function of data wrapper is for internal use only."
|
||||
|
||||
self.feature_dataset = feature_dataset
|
||||
self.backtest_dataset = backtest_dataset
|
||||
self.columns_today = columns_today
|
||||
self.columns_yesterday = columns_yesterday
|
||||
|
||||
# TODO: We might have the chance to merge them.
|
||||
self.feature_cache = LRUCache()
|
||||
self.backtest_cache = LRUCache()
|
||||
|
||||
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
|
||||
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
|
||||
|
||||
if backtest:
|
||||
dataset = self.backtest_dataset
|
||||
cache = self.backtest_cache
|
||||
else:
|
||||
dataset = self.feature_dataset
|
||||
cache = self.feature_cache
|
||||
|
||||
if cache.has((start_time, end_time, stock_id)):
|
||||
return cache.get((start_time, end_time, stock_id))
|
||||
data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
|
||||
cache.put((start_time, end_time, stock_id), data)
|
||||
return data
|
||||
|
||||
|
||||
def init_qlib(config: dict, part: Optional[str] = None) -> None:
|
||||
provider_uri_map = {
|
||||
"day": config["provider_uri_day"].as_posix(),
|
||||
"1min": config["provider_uri_1min"].as_posix(),
|
||||
}
|
||||
qlib.init(
|
||||
region=REG_CN,
|
||||
auto_mount=False,
|
||||
custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum],
|
||||
expression_cache=None,
|
||||
calendar_provider={
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
},
|
||||
},
|
||||
},
|
||||
feature_provider={
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
},
|
||||
},
|
||||
},
|
||||
provider_uri=provider_uri_map,
|
||||
kernels=1,
|
||||
redis_port=-1,
|
||||
clear_mem_cache=False, # init_qlib will be called for multiple times. Keep the cache for improving performance
|
||||
)
|
|
@ -3,13 +3,13 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar, Generic, Any
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
from .simulator import StateType, ActType
|
||||
from .simulator import ActType, StateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .utils.env_wrapper import EnvWrapper
|
||||
|
@ -40,7 +40,7 @@ class Interpreter:
|
|||
class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
||||
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@property
|
||||
def observation_space(self) -> gym.Space:
|
||||
|
@ -74,7 +74,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter):
|
|||
class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
|
||||
"""Action Interpreter that interpret rl agent action into qlib orders"""
|
||||
|
||||
env: "EnvWrapper" | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@property
|
||||
def action_space(self) -> gym.Space:
|
||||
|
@ -141,10 +141,10 @@ def _gym_space_contains(space: gym.Space, x: Any) -> None:
|
|||
|
||||
|
||||
class GymSpaceValidationError(Exception):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any):
|
||||
def __init__(self, message: str, space: gym.Space, x: Any) -> None:
|
||||
self.message = message
|
||||
self.space = space
|
||||
self.x = x
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.message}\n Space: {self.space}\n Sample: {self.x}"
|
||||
|
|
|
@ -5,15 +5,15 @@ from __future__ import annotations
|
|||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from typing import Any, List, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from gym import spaces
|
||||
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .simulator_simple import SAOEState
|
||||
|
@ -99,18 +99,18 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
|
|||
"data_processed": self._mask_future_info(processed.today, state.cur_time),
|
||||
"data_processed_prev": processed.yesterday,
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_tick": min(np.sum(state.ticks_index < state.cur_time), self.data_ticks - 1),
|
||||
"cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1),
|
||||
"cur_step": min(self.env.status["cur_step"], self.max_step - 1),
|
||||
"num_step": self.max_step,
|
||||
"target": state.order.amount,
|
||||
"position": state.position,
|
||||
"position_history": position_history[: self.max_step],
|
||||
}
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> spaces.Dict:
|
||||
space = {
|
||||
"data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
|
||||
"data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)),
|
||||
|
@ -147,11 +147,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
|||
The key list is not full. You can add more if more information is needed by your policy.
|
||||
"""
|
||||
|
||||
def __init__(self, max_step: int):
|
||||
def __init__(self, max_step: int) -> None:
|
||||
self.max_step = max_step
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> spaces.Dict:
|
||||
space = {
|
||||
"acquiring": spaces.Discrete(2),
|
||||
"cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32),
|
||||
|
@ -165,13 +165,11 @@ class CurrentStepStateInterpreter(StateInterpreter[SAOEState, CurrentStateObs]):
|
|||
assert self.env is not None
|
||||
assert self.env.status["cur_step"] <= self.max_step
|
||||
obs = CurrentStateObs(
|
||||
{
|
||||
"acquiring": state.order.direction == state.order.BUY,
|
||||
"cur_step": self.env.status["cur_step"],
|
||||
"num_step": self.max_step,
|
||||
"target": state.order.amount,
|
||||
"position": state.position,
|
||||
}
|
||||
acquiring=state.order.direction == state.order.BUY,
|
||||
cur_step=self.env.status["cur_step"],
|
||||
num_step=self.max_step,
|
||||
target=state.order.amount,
|
||||
position=state.position,
|
||||
)
|
||||
return obs
|
||||
|
||||
|
@ -188,7 +186,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
|||
i.e., $[0, 1/n, 2/n, \\ldots, n/n]$.
|
||||
"""
|
||||
|
||||
def __init__(self, values: int | list[float]):
|
||||
def __init__(self, values: int | List[float]) -> None:
|
||||
if isinstance(values, int):
|
||||
values = [i / values for i in range(0, values + 1)]
|
||||
self.action_values = values
|
||||
|
@ -203,7 +201,7 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
|
|||
|
||||
|
||||
class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]):
|
||||
"""Convert a continous ratio to deal amount.
|
||||
"""Convert a continuous ratio to deal amount.
|
||||
|
||||
The ratio is relative to TWAP on the remainder of the day.
|
||||
For example, there are 5 steps left, and the left position is 300.
|
||||
|
|
|
@ -3,13 +3,14 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
from typing import List, Tuple, cast
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tianshou.data import Batch
|
||||
|
||||
from qlib.typehint import Literal
|
||||
|
||||
from .interpreter import FullHistoryObs
|
||||
|
||||
__all__ = ["Recurrent"]
|
||||
|
@ -18,7 +19,7 @@ __all__ = ["Recurrent"]
|
|||
class Recurrent(nn.Module):
|
||||
"""The network architecture proposed in `OPD <https://seqml.github.io/opd/opd_aaai21_supplement.pdf>`_.
|
||||
|
||||
At every timestep the input of policy network is divided into two parts,
|
||||
At every time step the input of policy network is divided into two parts,
|
||||
the public variables and the private variables. which are handled by ``raw_rnn``
|
||||
and ``pri_rnn`` in this network, respectively.
|
||||
|
||||
|
@ -33,7 +34,7 @@ class Recurrent(nn.Module):
|
|||
output_dim: int = 32,
|
||||
rnn_type: Literal["rnn", "lstm", "gru"] = "gru",
|
||||
rnn_num_layers: int = 1,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
|
@ -62,10 +63,10 @@ class Recurrent(nn.Module):
|
|||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def _init_extra_branches(self):
|
||||
def _init_extra_branches(self) -> None:
|
||||
pass
|
||||
|
||||
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
||||
bs, _, data_dim = obs["data_processed"].size()
|
||||
data = torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
|
||||
cur_step = obs["cur_step"].long()
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Dict, Generator, Iterable, Optional, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from gym.spaces import Discrete
|
||||
from tianshou.data import Batch, to_torch
|
||||
from tianshou.policy import PPOPolicy, BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.policy import BasePolicy, PPOPolicy
|
||||
|
||||
__all__ = ["AllOne", "PPO"]
|
||||
|
||||
|
@ -18,29 +19,39 @@ __all__ = ["AllOne", "PPO"]
|
|||
# baselines #
|
||||
|
||||
|
||||
class NonlearnablePolicy(BasePolicy):
|
||||
class NonLearnablePolicy(BasePolicy):
|
||||
"""Tianshou's BasePolicy with empty ``learn`` and ``process_fn``.
|
||||
|
||||
This could be moved outside in future.
|
||||
"""
|
||||
|
||||
def __init__(self, obs_space: gym.Space, action_space: gym.Space):
|
||||
def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None:
|
||||
super().__init__()
|
||||
|
||||
def learn(self, batch, batch_size, repeat):
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def process_fn(self, batch, buffer, indice):
|
||||
def process_fn(
|
||||
self,
|
||||
batch: Batch,
|
||||
buffer: ReplayBuffer,
|
||||
indices: np.ndarray,
|
||||
) -> Batch:
|
||||
pass
|
||||
|
||||
|
||||
class AllOne(NonlearnablePolicy):
|
||||
class AllOne(NonLearnablePolicy):
|
||||
"""Forward returns a batch full of 1.
|
||||
|
||||
Useful when implementing some baselines (e.g., TWAP).
|
||||
"""
|
||||
|
||||
def forward(self, batch, state=None, **kwargs):
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: dict | Batch | np.ndarray = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
return Batch(act=np.full(len(batch), 1.0), state=state)
|
||||
|
||||
|
||||
|
@ -48,24 +59,34 @@ class AllOne(NonlearnablePolicy):
|
|||
|
||||
|
||||
class PPOActor(nn.Module):
|
||||
def __init__(self, extractor: nn.Module, action_dim: int):
|
||||
def __init__(self, extractor: nn.Module, action_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1))
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
info: dict = {},
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
feature = self.extractor(to_torch(obs, device=auto_device(self)))
|
||||
out = self.layer_out(feature)
|
||||
return out, state
|
||||
|
||||
|
||||
class PPOCritic(nn.Module):
|
||||
def __init__(self, extractor: nn.Module):
|
||||
def __init__(self, extractor: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.extractor = extractor
|
||||
self.value_out = nn.Linear(cast(int, extractor.output_dim), 1)
|
||||
|
||||
def forward(self, obs, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
obs: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
info: dict = {},
|
||||
) -> torch.Tensor:
|
||||
feature = self.extractor(to_torch(obs, device=auto_device(self)))
|
||||
return self.value_out(feature).squeeze(dim=-1)
|
||||
|
||||
|
@ -93,18 +114,20 @@ class PPO(PPOPolicy):
|
|||
max_grad_norm: float = 100.0,
|
||||
reward_normalization: bool = True,
|
||||
eps_clip: float = 0.3,
|
||||
value_clip: float = True,
|
||||
value_clip: bool = True,
|
||||
vf_coef: float = 1.0,
|
||||
gae_lambda: float = 1.0,
|
||||
max_batchsize: int = 256,
|
||||
max_batch_size: int = 256,
|
||||
deterministic_eval: bool = True,
|
||||
weight_file: Optional[Path] = None,
|
||||
):
|
||||
) -> None:
|
||||
assert isinstance(action_space, Discrete)
|
||||
actor = PPOActor(network, action_space.n)
|
||||
critic = PPOCritic(network)
|
||||
optimizer = torch.optim.Adam(
|
||||
chain_dedup(actor.parameters(), critic.parameters()), lr=lr, weight_decay=weight_decay
|
||||
chain_dedup(actor.parameters(), critic.parameters()),
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
super().__init__(
|
||||
actor,
|
||||
|
@ -118,7 +141,7 @@ class PPO(PPOPolicy):
|
|||
value_clip=value_clip,
|
||||
vf_coef=vf_coef,
|
||||
gae_lambda=gae_lambda,
|
||||
max_batchsize=max_batchsize,
|
||||
max_batchsize=max_batch_size,
|
||||
deterministic_eval=deterministic_eval,
|
||||
observation_space=obs_space,
|
||||
action_space=action_space,
|
||||
|
@ -136,7 +159,7 @@ def auto_device(module: nn.Module) -> torch.device:
|
|||
return torch.device("cpu") # fallback to cpu
|
||||
|
||||
|
||||
def load_weight(policy, path):
|
||||
def load_weight(policy: nn.Module, path: Path) -> None:
|
||||
assert isinstance(policy, nn.Module), "Policy has to be an nn.Module to load weight."
|
||||
loaded_weight = torch.load(path, map_location="cpu")
|
||||
try:
|
||||
|
@ -149,7 +172,7 @@ def load_weight(policy, path):
|
|||
policy.load_state_dict(loaded_weight)
|
||||
|
||||
|
||||
def chain_dedup(*iterables):
|
||||
def chain_dedup(*iterables: Iterable) -> Generator[Any, None, None]:
|
||||
seen = set()
|
||||
for iterable in iterables:
|
||||
for i in iterable:
|
||||
|
|
|
@ -6,9 +6,10 @@ from __future__ import annotations
|
|||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.rl.reward import Reward
|
||||
|
||||
from .simulator_simple import SAOEState, SAOEMetrics
|
||||
from .simulator_simple import SAOEMetrics, SAOEState
|
||||
|
||||
__all__ = ["PAPenaltyReward"]
|
||||
|
||||
|
|
|
@ -1,4 +1,424 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Placeholder for qlib-based simulator."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, cast, Generator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderHelper, TradeDecisionWO, TradeRange, TradeRangeByTime
|
||||
from qlib.backtest.executor import BaseExecutor, NestedExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.exchange_wrapper import QlibIntradayBacktestData
|
||||
from qlib.rl.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.from_neutrader.feature import init_qlib
|
||||
from qlib.rl.order_execution.simulator_simple import SAOEMetrics, SAOEState
|
||||
from qlib.rl.order_execution.utils import (
|
||||
dataframe_append,
|
||||
get_common_infra,
|
||||
get_portfolio_and_indicator,
|
||||
get_ticks_slice,
|
||||
price_advantage,
|
||||
)
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
|
||||
|
||||
class DecomposedStrategy(BaseStrategy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.execute_order: Optional[Order] = None
|
||||
self.execute_result: List[Tuple[Order, float, float, float]] = []
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
|
||||
# Once the following line is executed, this DecomposedStrategy (self) will be yielded to the outside
|
||||
# of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`,
|
||||
# the sent item will be captured by `exec_vol`. The outside policy could communicate with the inner
|
||||
# level strategy through this way.
|
||||
exec_vol = yield self
|
||||
|
||||
oh = self.trade_exchange.get_order_helper()
|
||||
order = oh.create(self._order.stock_id, exec_vol, self._order.direction)
|
||||
|
||||
self.execute_order = order
|
||||
|
||||
return TradeDecisionWO([order], self)
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def post_exe_step(self, execute_result: list) -> None:
|
||||
self.execute_result = execute_result
|
||||
|
||||
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs: Any) -> None:
|
||||
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
order_list = outer_trade_decision.order_list
|
||||
assert len(order_list) == 1
|
||||
self._order = order_list[0]
|
||||
|
||||
|
||||
class SingleOrderStrategy(BaseStrategy):
|
||||
# this logic is copied from FileOrderStrategy
|
||||
def __init__(
|
||||
self,
|
||||
common_infra: CommonInfrastructure,
|
||||
order: Order,
|
||||
trade_range: TradeRange,
|
||||
instrument: str,
|
||||
) -> None:
|
||||
super().__init__(common_infra=common_infra)
|
||||
self._order = order
|
||||
self._trade_range = trade_range
|
||||
self._instrument = instrument
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
return outer_trade_decision
|
||||
|
||||
def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionWO:
|
||||
oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper()
|
||||
order_list = [
|
||||
oh.create(
|
||||
code=self._instrument,
|
||||
amount=self._order.amount,
|
||||
direction=self._order.direction,
|
||||
),
|
||||
]
|
||||
return TradeDecisionWO(order_list, self, self._trade_range)
|
||||
|
||||
|
||||
# TODO: move these to the configuration files
|
||||
FINEST_GRANULARITY = "1min"
|
||||
COARSEST_GRANULARITY = "1day"
|
||||
|
||||
|
||||
class StateMaintainer:
|
||||
"""
|
||||
Maintain states of the environment.
|
||||
|
||||
Example usage::
|
||||
|
||||
maintainer = StateMaintainer(...) # in reset
|
||||
maintainer.update(...) # in step
|
||||
# get states in get_state from maintainer
|
||||
"""
|
||||
|
||||
def __init__(self, order: Order, time_per_step: str, tick_index: pd.DatetimeIndex, twap_price: float) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.position = order.amount
|
||||
self._order = order
|
||||
self._time_per_step = time_per_step
|
||||
self._tick_index = tick_index
|
||||
self._twap_price = twap_price
|
||||
|
||||
metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member
|
||||
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics: Optional[SAOEMetrics] = None
|
||||
|
||||
def update(
|
||||
self,
|
||||
inner_executor: BaseExecutor,
|
||||
inner_strategy: DecomposedStrategy,
|
||||
done: bool,
|
||||
all_indicators: dict,
|
||||
) -> None:
|
||||
execute_order = inner_strategy.execute_order
|
||||
execute_result = inner_strategy.execute_result
|
||||
exec_vol = np.array([e[0].deal_amount for e in execute_result])
|
||||
num_step = len(execute_result)
|
||||
|
||||
assert execute_order is not None
|
||||
|
||||
if num_step == 0:
|
||||
market_volume = np.array([])
|
||||
market_price = np.array([])
|
||||
datetime_list = pd.DatetimeIndex([])
|
||||
else:
|
||||
market_volume = np.array(
|
||||
inner_executor.trade_exchange.get_volume(
|
||||
execute_order.stock_id,
|
||||
execute_result[0][0].start_time,
|
||||
execute_result[-1][0].start_time,
|
||||
method=None,
|
||||
),
|
||||
)
|
||||
|
||||
trade_value = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["value"].values
|
||||
deal_amount = all_indicators[FINEST_GRANULARITY].iloc[-num_step:]["deal_amount"].values
|
||||
market_price = trade_value / deal_amount
|
||||
|
||||
datetime_list = all_indicators[FINEST_GRANULARITY].index[-num_step:]
|
||||
|
||||
assert market_price.shape == market_volume.shape == exec_vol.shape
|
||||
|
||||
self.history_exec = dataframe_append(
|
||||
self.history_exec,
|
||||
self._collect_multi_order_metric(
|
||||
order=self._order,
|
||||
datetime=datetime_list,
|
||||
market_vol=market_volume,
|
||||
market_price=market_price,
|
||||
exec_vol=exec_vol,
|
||||
pa=all_indicators[self._time_per_step].iloc[-1]["pa"],
|
||||
),
|
||||
)
|
||||
|
||||
self.history_steps = dataframe_append(
|
||||
self.history_steps,
|
||||
[
|
||||
self._collect_single_order_metric(
|
||||
execute_order,
|
||||
execute_order.start_time,
|
||||
market_volume,
|
||||
market_price,
|
||||
exec_vol.sum(),
|
||||
exec_vol,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
if done:
|
||||
self.metrics = self._collect_single_order_metric(
|
||||
self._order,
|
||||
self._tick_index[0], # start time
|
||||
self.history_exec["market_volume"],
|
||||
self.history_exec["market_price"],
|
||||
self.history_steps["amount"].sum(),
|
||||
self.history_exec["deal_amount"],
|
||||
)
|
||||
|
||||
# TODO: check whether we need this. Can we get this information from Account?
|
||||
# Do this at the end
|
||||
self.position -= exec_vol.sum()
|
||||
|
||||
def _collect_multi_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
exec_vol: np.ndarray,
|
||||
pa: float,
|
||||
) -> SAOEMetrics:
|
||||
return SAOEMetrics(
|
||||
# It should have the same keys with SAOEMetrics,
|
||||
# but the values do not necessarily have the annotated type.
|
||||
# Some values could be vectorized (e.g., exec_vol).
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol,
|
||||
market_price=market_price,
|
||||
amount=exec_vol,
|
||||
inner_amount=exec_vol,
|
||||
deal_amount=exec_vol,
|
||||
trade_price=market_price,
|
||||
trade_value=market_price * exec_vol,
|
||||
position=self.position - np.cumsum(exec_vol),
|
||||
ffr=exec_vol / order.amount,
|
||||
pa=pa,
|
||||
)
|
||||
|
||||
def _collect_single_order_metric(
|
||||
self,
|
||||
order: Order,
|
||||
datetime: pd.Timestamp,
|
||||
market_vol: np.ndarray,
|
||||
market_price: np.ndarray,
|
||||
amount: float, # intended to trade such amount
|
||||
exec_vol: np.ndarray,
|
||||
) -> SAOEMetrics:
|
||||
assert len(market_vol) == len(market_price) == len(exec_vol)
|
||||
|
||||
if np.abs(np.sum(exec_vol)) < EPS:
|
||||
exec_avg_price = 0.0
|
||||
else:
|
||||
exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan
|
||||
if hasattr(exec_avg_price, "item"): # could be numpy scalar
|
||||
exec_avg_price = exec_avg_price.item() # type: ignore
|
||||
|
||||
exec_sum = exec_vol.sum()
|
||||
return SAOEMetrics(
|
||||
stock_id=order.stock_id,
|
||||
datetime=datetime,
|
||||
direction=order.direction,
|
||||
market_volume=market_vol.sum(),
|
||||
market_price=market_price.mean() if len(market_price) > 0 else np.nan,
|
||||
amount=amount,
|
||||
inner_amount=exec_sum,
|
||||
deal_amount=exec_sum, # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position - exec_sum,
|
||||
ffr=float(exec_sum / order.amount),
|
||||
pa=price_advantage(exec_avg_price, self._twap_price, order.direction),
|
||||
)
|
||||
|
||||
|
||||
class SingleAssetOrderExecutionQlib(Simulator[Order, SAOEState, float]):
|
||||
"""Single-asset order execution (SAOE) simulator which is implemented based on Qlib backtest tools.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
order (Order):
|
||||
The seed to start an SAOE simulator is an order.
|
||||
time_per_step (str):
|
||||
A string to describe the time granularity of each step. Current support "1min", "30min", and "1day"
|
||||
qlib_config (dict):
|
||||
Configuration used to initialize Qlib.
|
||||
inner_executor_fn (Callable[[str, CommonInfrastructure], BaseExecutor]):
|
||||
Function used to get the inner level executor.
|
||||
exchange_config (ExchangeConfig):
|
||||
Configuration used to create the Exchange instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
order: Order,
|
||||
time_per_step: str, # "1min", "30min", "1day"
|
||||
qlib_config: dict,
|
||||
inner_executor_fn: Callable[[str, CommonInfrastructure], BaseExecutor],
|
||||
exchange_config: ExchangeConfig,
|
||||
) -> None:
|
||||
assert time_per_step in ("1min", "30min", "1day")
|
||||
|
||||
super().__init__(initial=order)
|
||||
|
||||
assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same."
|
||||
|
||||
self._order = order
|
||||
self._order_date = pd.Timestamp(order.start_time.date())
|
||||
self._trade_range = TradeRangeByTime(order.start_time.time(), order.end_time.time())
|
||||
self._qlib_config = qlib_config
|
||||
self._inner_executor_fn = inner_executor_fn
|
||||
self._exchange_config = exchange_config
|
||||
|
||||
self._time_per_step = time_per_step
|
||||
self._ticks_per_step = int(pd.Timedelta(time_per_step).total_seconds() // 60)
|
||||
|
||||
self._executor: Optional[NestedExecutor] = None
|
||||
self._collect_data_loop: Optional[Generator] = None
|
||||
|
||||
self._done = False
|
||||
|
||||
self._inner_strategy = DecomposedStrategy()
|
||||
|
||||
self.reset(self._order)
|
||||
|
||||
def reset(self, order: Order) -> None:
|
||||
instrument = order.stock_id
|
||||
|
||||
# TODO: Check this logic. Make sure we need to do this every time we reset the simulator.
|
||||
init_qlib(self._qlib_config, instrument)
|
||||
|
||||
common_infra = get_common_infra(
|
||||
self._exchange_config,
|
||||
trade_date=pd.Timestamp(self._order_date),
|
||||
codes=[instrument],
|
||||
)
|
||||
|
||||
# TODO: We can leverage interfaces like (https://tinyurl.com/y8f8fhv4) to create trading environment.
|
||||
# TODO: By aligning the interface to create environments with Qlib, it will be easier to share the config and
|
||||
# TODO: code between backtesting and training.
|
||||
self._inner_executor = self._inner_executor_fn(self._time_per_step, common_infra)
|
||||
self._executor = NestedExecutor(
|
||||
time_per_step=COARSEST_GRANULARITY,
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
track_data=True,
|
||||
common_infra=common_infra,
|
||||
)
|
||||
|
||||
exchange = self._inner_executor.trade_exchange
|
||||
self._ticks_index = pd.DatetimeIndex([e[1] for e in list(exchange.quote_df.index)])
|
||||
self._ticks_for_order = get_ticks_slice(
|
||||
self._ticks_index,
|
||||
self._order.start_time,
|
||||
self._order.end_time,
|
||||
include_end=True,
|
||||
)
|
||||
|
||||
self._backtest_data = QlibIntradayBacktestData(
|
||||
order=self._order,
|
||||
exchange=exchange,
|
||||
start_time=self._ticks_for_order[0],
|
||||
end_time=self._ticks_for_order[-1],
|
||||
)
|
||||
|
||||
self.twap_price = self._backtest_data.get_deal_price().mean()
|
||||
|
||||
top_strategy = SingleOrderStrategy(common_infra, order, self._trade_range, instrument)
|
||||
self._executor.reset(start_time=pd.Timestamp(self._order_date), end_time=pd.Timestamp(self._order_date))
|
||||
top_strategy.reset(level_infra=self._executor.get_level_infra())
|
||||
|
||||
self._collect_data_loop = self._executor.collect_data(top_strategy.generate_trade_decision(), level=0)
|
||||
assert isinstance(self._collect_data_loop, Generator)
|
||||
|
||||
self._iter_strategy(action=None)
|
||||
self._done = False
|
||||
|
||||
self._maintainer = StateMaintainer(
|
||||
order=self._order,
|
||||
time_per_step=self._time_per_step,
|
||||
tick_index=self._ticks_index,
|
||||
twap_price=self.twap_price,
|
||||
)
|
||||
|
||||
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
|
||||
"""Iterate the _collect_data_loop until we get the next yield DecomposedStrategy."""
|
||||
assert self._collect_data_loop is not None
|
||||
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
while not isinstance(strategy, DecomposedStrategy):
|
||||
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
|
||||
assert isinstance(strategy, DecomposedStrategy)
|
||||
return strategy
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action (float):
|
||||
The amount you wish to deal. The simulator doesn't guarantee all the amount to be successfully dealt.
|
||||
"""
|
||||
|
||||
assert not self._done, "Simulator has already done!"
|
||||
|
||||
try:
|
||||
self._iter_strategy(action=action)
|
||||
except StopIteration:
|
||||
self._done = True
|
||||
|
||||
assert self._executor is not None
|
||||
_, all_indicators = get_portfolio_and_indicator(self._executor)
|
||||
|
||||
self._maintainer.update(
|
||||
inner_executor=self._inner_executor,
|
||||
inner_strategy=self._inner_strategy,
|
||||
done=self._done,
|
||||
all_indicators=all_indicators,
|
||||
)
|
||||
|
||||
def get_state(self) -> SAOEState:
|
||||
return SAOEState(
|
||||
order=self._order,
|
||||
cur_time=self._inner_executor.trade_calendar.get_step_time()[0],
|
||||
position=self._maintainer.position,
|
||||
history_exec=self._maintainer.history_exec,
|
||||
history_steps=self._maintainer.history_steps,
|
||||
metrics=self._maintainer.metrics,
|
||||
backtest_data=self._backtest_data,
|
||||
ticks_per_step=self._ticks_per_step,
|
||||
ticks_index=self._ticks_index,
|
||||
ticks_for_order=self._ticks_for_order,
|
||||
)
|
||||
|
||||
def done(self) -> bool:
|
||||
return self._done
|
||||
|
|
|
@ -4,18 +4,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Any, TypeVar, cast
|
||||
from typing import Any, NamedTuple, Optional, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.constant import EPS
|
||||
from qlib.rl.data.pickle_styled import DealPriceType, IntradayBacktestData, load_simple_intraday_backtest_data
|
||||
from qlib.rl.simulator import Simulator
|
||||
from qlib.rl.data.pickle_styled import IntradayBacktestData, load_intraday_backtest_data, DealPriceType
|
||||
from qlib.rl.utils import LogLevel
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
# TODO: Integrating Qlib's native data with simulator_simple
|
||||
|
||||
__all__ = ["SAOEMetrics", "SAOEState", "SingleAssetOrderExecution"]
|
||||
|
||||
ONE_SEC = pd.Timedelta("1s") # use 1 second to exclude the right interval point
|
||||
|
@ -33,40 +35,40 @@ class SAOEMetrics(TypedDict):
|
|||
|
||||
stock_id: str
|
||||
"""Stock ID of this record."""
|
||||
datetime: pd.Timestamp
|
||||
datetime: pd.Timestamp | pd.DatetimeIndex # TODO: check this
|
||||
"""Datetime of this record (this is index in the dataframe)."""
|
||||
direction: int
|
||||
"""Direction of the order. 0 for sell, 1 for buy."""
|
||||
|
||||
# Market information.
|
||||
market_volume: float
|
||||
market_volume: np.ndarray | float
|
||||
"""(total) market volume traded in the period."""
|
||||
market_price: float
|
||||
market_price: np.ndarray | float
|
||||
"""Deal price. If it's a period of time, this is the average market deal price."""
|
||||
|
||||
# Strategy records.
|
||||
|
||||
amount: float
|
||||
amount: np.ndarray | float
|
||||
"""Total amount (volume) strategy intends to trade."""
|
||||
inner_amount: float
|
||||
inner_amount: np.ndarray | float
|
||||
"""Total amount that the lower-level strategy intends to trade
|
||||
(might be larger than amount, e.g., to ensure ffr)."""
|
||||
|
||||
deal_amount: float
|
||||
deal_amount: np.ndarray | float
|
||||
"""Amount that successfully takes effect (must be less than inner_amount)."""
|
||||
trade_price: float
|
||||
trade_price: np.ndarray | float
|
||||
"""The average deal price for this strategy."""
|
||||
trade_value: float
|
||||
"""Total worth of trading. In the simple simulaton, trade_value = deal_amount * price."""
|
||||
position: float
|
||||
trade_value: np.ndarray | float
|
||||
"""Total worth of trading. In the simple simulation, trade_value = deal_amount * price."""
|
||||
position: np.ndarray | float
|
||||
"""Position left after this "period"."""
|
||||
|
||||
# Accumulated metrics
|
||||
|
||||
ffr: float
|
||||
ffr: np.ndarray | float
|
||||
"""Completed how much percent of the daily order."""
|
||||
|
||||
pa: float
|
||||
pa: np.ndarray | float
|
||||
"""Price advantage compared to baseline (i.e., trade with baseline market price).
|
||||
The baseline is trade price when using TWAP strategy to execute this order.
|
||||
Please note that there could be data leak here).
|
||||
|
@ -87,7 +89,7 @@ class SAOEState(NamedTuple):
|
|||
history_steps: pd.DataFrame
|
||||
"""See :attr:`SingleAssetOrderExecution.history_steps`."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Daily metric, only available when the trading is in "done" state."""
|
||||
|
||||
backtest_data: IntradayBacktestData
|
||||
|
@ -114,13 +116,13 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
|||
If such fine granularity is not needed, use ``ticks_per_step`` to
|
||||
lengthen the ticks for each step.
|
||||
|
||||
In each step, the traded amount are "equally" splitted to each tick,
|
||||
then bounded by volume maximum exeuction volume (i.e., ``vol_threshold``),
|
||||
In each step, the traded amount are "equally" separated to each tick,
|
||||
then bounded by volume maximum execution volume (i.e., ``vol_threshold``),
|
||||
and if it's the last step, try to ensure all the amount to be executed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
initial
|
||||
order
|
||||
The seed to start an SAOE simulator is an order.
|
||||
ticks_per_step
|
||||
How many ticks per step.
|
||||
|
@ -140,7 +142,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
|||
See :class:`SAOEMetrics` for available columns.
|
||||
Index is ``datetime``, which is the **starting** time of each step."""
|
||||
|
||||
metrics: SAOEMetrics | None
|
||||
metrics: Optional[SAOEMetrics]
|
||||
"""Metrics. Only available when done."""
|
||||
|
||||
twap_price: float
|
||||
|
@ -159,15 +161,21 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
|||
data_dir: Path,
|
||||
ticks_per_step: int = 30,
|
||||
deal_price_type: DealPriceType = "close",
|
||||
vol_threshold: float | None = None,
|
||||
vol_threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(initial=order)
|
||||
|
||||
self.order = order
|
||||
self.ticks_per_step: int = ticks_per_step
|
||||
self.deal_price_type = deal_price_type
|
||||
self.vol_threshold = vol_threshold
|
||||
self.data_dir = data_dir
|
||||
self.backtest_data = load_intraday_backtest_data(
|
||||
self.data_dir, order.stock_id, pd.Timestamp(order.start_time.date()), self.deal_price_type, order.direction
|
||||
self.backtest_data = load_simple_intraday_backtest_data(
|
||||
self.data_dir,
|
||||
order.stock_id,
|
||||
pd.Timestamp(order.start_time.date()),
|
||||
self.deal_price_type,
|
||||
order.direction,
|
||||
)
|
||||
|
||||
self.ticks_index = self.backtest_data.get_time_index()
|
||||
|
@ -188,9 +196,9 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
|||
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
|
||||
self.metrics = None
|
||||
|
||||
self.market_price: np.ndarray | None = None
|
||||
self.market_vol: np.ndarray | None = None
|
||||
self.market_vol_limit: np.ndarray | None = None
|
||||
self.market_price: Optional[np.ndarray] = None
|
||||
self.market_vol: Optional[np.ndarray] = None
|
||||
self.market_vol_limit: Optional[np.ndarray] = None
|
||||
|
||||
def step(self, amount: float) -> None:
|
||||
"""Execute one step or SAOE.
|
||||
|
@ -205,7 +213,8 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
|||
|
||||
self.market_price = self.market_vol = None # avoid misuse
|
||||
exec_vol = self._split_exec_vol(amount)
|
||||
assert self.market_price is not None and self.market_vol is not None
|
||||
assert self.market_price is not None
|
||||
assert self.market_vol is not None
|
||||
|
||||
ticks_position = self.position - np.cumsum(exec_vol)
|
||||
|
||||
|
@ -363,7 +372,7 @@ class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]):
|
|||
inner_amount=exec_vol.sum(),
|
||||
deal_amount=exec_vol.sum(), # in this simulator, there's no other restrictions
|
||||
trade_price=exec_avg_price,
|
||||
trade_value=np.sum(market_price * exec_vol),
|
||||
trade_value=float(np.sum(market_price * exec_vol)),
|
||||
position=self.position,
|
||||
ffr=float(exec_vol.sum() / self.order.amount),
|
||||
pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction),
|
||||
|
@ -386,7 +395,9 @@ _float_or_ndarray = TypeVar("_float_or_ndarray", float, np.ndarray)
|
|||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray, baseline_price: float, direction: OrderDir | int
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Tuple, cast
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest import CommonInfrastructure, get_exchange
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.decision import OrderDir
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from qlib.rl.from_neutrader.config import ExchangeConfig
|
||||
from qlib.rl.order_execution.simulator_simple import ONE_SEC, _float_or_ndarray
|
||||
from qlib.utils.time import Freq
|
||||
|
||||
|
||||
def get_common_infra(
|
||||
config: ExchangeConfig,
|
||||
trade_date: pd.Timestamp,
|
||||
codes: List[str],
|
||||
cash_limit: float = None,
|
||||
) -> CommonInfrastructure:
|
||||
# need to specify a range here for acceleration
|
||||
if cash_limit is None:
|
||||
trade_account = Account(init_cash=int(1e12), benchmark_config={}, pos_type="InfPosition")
|
||||
else:
|
||||
trade_account = Account(
|
||||
init_cash=cash_limit,
|
||||
benchmark_config={},
|
||||
pos_type="Position",
|
||||
position_dict={code: {"amount": 1e12, "price": 1.0} for code in codes},
|
||||
)
|
||||
|
||||
exchange = get_exchange(
|
||||
codes=codes,
|
||||
freq="1min",
|
||||
limit_threshold=config.limit_threshold,
|
||||
deal_price=config.deal_price,
|
||||
open_cost=config.open_cost,
|
||||
close_cost=config.close_cost,
|
||||
min_cost=config.min_cost if config.trade_unit is not None else 0,
|
||||
start_time=trade_date,
|
||||
end_time=trade_date + pd.DateOffset(1),
|
||||
trade_unit=config.trade_unit,
|
||||
volume_threshold=config.volume_threshold,
|
||||
)
|
||||
|
||||
return CommonInfrastructure(trade_account=trade_account, trade_exchange=exchange)
|
||||
|
||||
|
||||
def get_ticks_slice(
|
||||
ticks_index: pd.DatetimeIndex,
|
||||
start: pd.Timestamp,
|
||||
end: pd.Timestamp,
|
||||
include_end: bool = False,
|
||||
) -> pd.DatetimeIndex:
|
||||
if not include_end:
|
||||
end = end - ONE_SEC
|
||||
return ticks_index[ticks_index.slice_indexer(start, end)]
|
||||
|
||||
|
||||
def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
|
||||
# dataframe.append is deprecated
|
||||
other_df = pd.DataFrame(other).set_index("datetime")
|
||||
other_df.index.name = "datetime"
|
||||
|
||||
res = pd.concat([df, other_df], axis=0)
|
||||
return res
|
||||
|
||||
|
||||
def price_advantage(
|
||||
exec_price: _float_or_ndarray,
|
||||
baseline_price: float,
|
||||
direction: OrderDir | int,
|
||||
) -> _float_or_ndarray:
|
||||
if baseline_price == 0: # something is wrong with data. Should be nan here
|
||||
if isinstance(exec_price, float):
|
||||
return 0.0
|
||||
else:
|
||||
return np.zeros_like(exec_price)
|
||||
if direction == OrderDir.BUY:
|
||||
res = (1 - exec_price / baseline_price) * 10000
|
||||
elif direction == OrderDir.SELL:
|
||||
res = (exec_price / baseline_price - 1) * 10000
|
||||
else:
|
||||
raise ValueError(f"Unexpected order direction: {direction}")
|
||||
res_wo_nan: np.ndarray = np.nan_to_num(res, nan=0.0)
|
||||
if res_wo_nan.size == 1:
|
||||
return res_wo_nan.item()
|
||||
else:
|
||||
return cast(_float_or_ndarray, res_wo_nan)
|
||||
|
||||
|
||||
def get_portfolio_and_indicator(executor: BaseExecutor) -> Tuple[dict, dict]:
|
||||
all_executors = executor.get_all_executors()
|
||||
all_portfolio_metrics = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.trade_account.get_portfolio_metrics()
|
||||
for _executor in all_executors
|
||||
if _executor.trade_account.is_port_metr_enabled()
|
||||
}
|
||||
|
||||
all_indicators = {}
|
||||
for _executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.trade_account.get_trade_indicator()
|
||||
|
||||
return all_portfolio_metrics, all_indicators
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generic, Any, TypeVar, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar
|
||||
|
||||
from qlib.typehint import final
|
||||
|
||||
|
@ -20,7 +20,7 @@ class Reward(Generic[SimulatorState]):
|
|||
Subclass should implement ``reward(simulator_state)`` to implement their own reward calculation recipe.
|
||||
"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
@final
|
||||
def __call__(self, simulator_state: SimulatorState) -> float:
|
||||
|
@ -30,14 +30,15 @@ class Reward(Generic[SimulatorState]):
|
|||
"""Implement this method for your own reward."""
|
||||
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
|
||||
|
||||
def log(self, name, value):
|
||||
def log(self, name: str, value: Any) -> None:
|
||||
assert self.env is not None
|
||||
self.env.logger.add_scalar(name, value)
|
||||
|
||||
|
||||
class RewardCombination(Reward):
|
||||
"""Combination of multiple reward."""
|
||||
|
||||
def __init__(self, rewards: dict[str, tuple[Reward, float]]):
|
||||
def __init__(self, rewards: Dict[str, Tuple[Reward, float]]) -> None:
|
||||
self.rewards = rewards
|
||||
|
||||
def reward(self, simulator_state: Any) -> float:
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar, Generic, Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
from .seed import InitialStateType
|
||||
|
||||
|
@ -49,7 +49,7 @@ class Simulator(Generic[InitialStateType, StateType, ActType]):
|
|||
Simulators are discouraged to use this, because it's prone to induce errors.
|
||||
"""
|
||||
|
||||
env: EnvWrapper | None = None
|
||||
env: Optional[EnvWrapper] = None
|
||||
|
||||
def __init__(self, initial: InitialStateType, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
|
|
@ -3,17 +3,17 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, Sequence, cast, Any
|
||||
from typing import Any, Callable, Sequence, cast
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.utils import FiniteEnvType, LogWriter
|
||||
|
||||
from .vessel import TrainingVessel
|
||||
from .trainer import Trainer
|
||||
from .vessel import TrainingVessel
|
||||
|
||||
|
||||
def train(
|
||||
|
|
|
@ -12,7 +12,7 @@ import shutil
|
|||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
|
@ -6,13 +6,13 @@ from __future__ import annotations
|
|||
import copy
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, TypeVar, Sequence, cast
|
||||
from typing import Any, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import torch
|
||||
|
||||
from qlib.rl.simulator import InitialStateType
|
||||
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogCollector, LogWriter, LogBuffer, vectorize_env, LogLevel
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.simulator import InitialStateType
|
||||
from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogBuffer, LogCollector, LogLevel, LogWriter, vectorize_env
|
||||
from qlib.rl.utils.finite_env import FiniteVectorEnv
|
||||
from qlib.typehint import Literal
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Callable, ContextManager, Generic, Iterable, TYPE_CHECKING, Sequence, Any, TypeVar, cast, Dict
|
||||
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
|
@ -12,12 +12,11 @@ from tianshou.env import BaseVectorEnv
|
|||
from tianshou.policy import BasePolicy
|
||||
|
||||
from qlib.constant import INF
|
||||
from qlib.rl.interpreter import StateType, ActType, ObsType, PolicyActType
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.utils import DataQueue
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.rl.interpreter import ActionInterpreter, ActType, ObsType, PolicyActType, StateInterpreter, StateType
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import InitialStateType, Simulator
|
||||
from qlib.rl.utils import DataQueue
|
||||
from qlib.rl.utils.finite_env import FiniteVectorEnv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -209,6 +208,9 @@ class TrainingVessel(TrainingVesselBase):
|
|||
order = np.random.permutation(len(collection))
|
||||
res = [collection[o] for o in order[:size]]
|
||||
_logger.info(
|
||||
"Fast running in development mode. Cut %s initial states from %d to %d.", name, len(collection), len(res)
|
||||
"Fast running in development mode. Cut %s initial states from %d to %d.",
|
||||
name,
|
||||
len(collection),
|
||||
len(res),
|
||||
)
|
||||
return res
|
||||
|
|
|
@ -1,7 +1,21 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .data_queue import *
|
||||
from .env_wrapper import *
|
||||
from .finite_env import *
|
||||
from .log import *
|
||||
from .data_queue import DataQueue
|
||||
from .env_wrapper import EnvWrapper, EnvWrapperStatus
|
||||
from .finite_env import FiniteEnvType, vectorize_env
|
||||
from .log import ConsoleWriter, CsvWriter, LogBuffer, LogCollector, LogLevel, LogWriter
|
||||
|
||||
__all__ = [
|
||||
"LogLevel",
|
||||
"DataQueue",
|
||||
"EnvWrapper",
|
||||
"FiniteEnvType",
|
||||
"LogCollector",
|
||||
"LogWriter",
|
||||
"vectorize_env",
|
||||
"ConsoleWriter",
|
||||
"CsvWriter",
|
||||
"EnvWrapperStatus",
|
||||
"LogBuffer",
|
||||
]
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from queue import Empty
|
||||
from typing import TypeVar, Generic, Sequence, cast
|
||||
from typing import Any, Generator, Generic, Sequence, TypeVar, cast
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
@ -60,7 +62,7 @@ class DataQueue(Generic[T]):
|
|||
shuffle: bool = True,
|
||||
producer_num_workers: int = 0,
|
||||
queue_maxsize: int = 0,
|
||||
):
|
||||
) -> None:
|
||||
if queue_maxsize == 0:
|
||||
if os.cpu_count() is not None:
|
||||
queue_maxsize = cast(int, os.cpu_count())
|
||||
|
@ -78,14 +80,14 @@ class DataQueue(Generic[T]):
|
|||
self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize)
|
||||
self._done = multiprocessing.Value("i", 0)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> DataQueue:
|
||||
self.activate()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
def cleanup(self) -> None:
|
||||
with self._done.get_lock():
|
||||
self._done.value += 1
|
||||
for repeat in range(500):
|
||||
|
@ -105,7 +107,7 @@ class DataQueue(Generic[T]):
|
|||
break
|
||||
_logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}")
|
||||
|
||||
def get(self, block=True):
|
||||
def get(self, block: bool = True) -> Any:
|
||||
if not hasattr(self, "_first_get"):
|
||||
self._first_get = True
|
||||
if self._first_get:
|
||||
|
@ -120,17 +122,17 @@ class DataQueue(Generic[T]):
|
|||
if self._done.value:
|
||||
raise StopIteration # pylint: disable=raise-missing-from
|
||||
|
||||
def put(self, obj, block=True, timeout=None):
|
||||
return self._queue.put(obj, block=block, timeout=timeout)
|
||||
def put(self, obj: Any, block: bool = True, timeout: int = None) -> None:
|
||||
self._queue.put(obj, block=block, timeout=timeout)
|
||||
|
||||
def mark_as_done(self):
|
||||
def mark_as_done(self) -> None:
|
||||
with self._done.get_lock():
|
||||
self._done.value = 1
|
||||
|
||||
def done(self):
|
||||
def done(self) -> int:
|
||||
return self._done.value
|
||||
|
||||
def activate(self):
|
||||
def activate(self) -> DataQueue:
|
||||
if self._activated:
|
||||
raise ValueError("DataQueue can not activate twice.")
|
||||
thread = threading.Thread(target=self._producer, daemon=True)
|
||||
|
@ -138,20 +140,20 @@ class DataQueue(Generic[T]):
|
|||
self._activated = True
|
||||
return self
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
_logger.debug(f"__del__ of {__name__}.DataQueue")
|
||||
self.cleanup()
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Generator[Any, None, None]:
|
||||
if not self._activated:
|
||||
raise ValueError(
|
||||
"Need to call activate() to launch a daemon worker "
|
||||
"to produce data into data queue before using it. "
|
||||
"You probably have forgotten to use the DataQueue in a with block."
|
||||
"You probably have forgotten to use the DataQueue in a with block.",
|
||||
)
|
||||
return self._consumer()
|
||||
|
||||
def _consumer(self):
|
||||
def _consumer(self) -> Generator[Any, None, None]:
|
||||
while True:
|
||||
try:
|
||||
yield self.get()
|
||||
|
@ -159,7 +161,7 @@ class DataQueue(Generic[T]):
|
|||
_logger.debug("Data consumer timed-out from get.")
|
||||
return
|
||||
|
||||
def _producer(self):
|
||||
def _producer(self) -> None:
|
||||
# pytorch dataloader is used here only because we need its sampler and multi-processing
|
||||
from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel
|
||||
|
||||
|
|
|
@ -4,14 +4,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from typing import Callable, Any, Iterable, Iterator, Generic, cast
|
||||
from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
|
||||
|
||||
import gym
|
||||
from gym import Space
|
||||
|
||||
from qlib.rl.aux_info import AuxiliaryInfoCollector
|
||||
from qlib.rl.simulator import Simulator, InitialStateType, StateType, ActType
|
||||
from qlib.rl.interpreter import StateInterpreter, ActionInterpreter, PolicyActType, ObsType
|
||||
from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter
|
||||
from qlib.rl.reward import Reward
|
||||
from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
|
||||
from qlib.typehint import TypedDict
|
||||
|
||||
from .finite_env import generate_nan_observation
|
||||
|
@ -28,7 +29,7 @@ class InfoDict(TypedDict):
|
|||
|
||||
aux_info: dict
|
||||
"""Any information depends on auxiliary info collector."""
|
||||
log: dict[str, Any]
|
||||
log: Dict[str, Any]
|
||||
"""Collected by LogCollector."""
|
||||
|
||||
|
||||
|
@ -42,14 +43,15 @@ class EnvWrapperStatus(TypedDict):
|
|||
|
||||
cur_step: int
|
||||
done: bool
|
||||
initial_state: Any | None
|
||||
initial_state: Optional[Any]
|
||||
obs_history: list
|
||||
action_history: list
|
||||
reward_history: list
|
||||
|
||||
|
||||
class EnvWrapper(
|
||||
gym.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]
|
||||
gym.Env[ObsType, PolicyActType],
|
||||
Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
|
||||
):
|
||||
"""Qlib-based RL environment, subclassing ``gym.Env``.
|
||||
A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
|
||||
|
@ -97,11 +99,11 @@ class EnvWrapper(
|
|||
simulator_fn: Callable[..., Simulator[InitialStateType, StateType, ActType]],
|
||||
state_interpreter: StateInterpreter[StateType, ObsType],
|
||||
action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType],
|
||||
seed_iterator: Iterable[InitialStateType] | None,
|
||||
reward_fn: Reward | None = None,
|
||||
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] | None = None,
|
||||
logger: LogCollector | None = None,
|
||||
):
|
||||
seed_iterator: Optional[Iterable[InitialStateType]],
|
||||
reward_fn: Reward = None,
|
||||
aux_info_collector: AuxiliaryInfoCollector[StateType, Any] = None,
|
||||
logger: LogCollector = None,
|
||||
) -> None:
|
||||
# Assign weak reference to wrapper.
|
||||
#
|
||||
# Use weak reference here, because:
|
||||
|
@ -135,11 +137,11 @@ class EnvWrapper(
|
|||
self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
def action_space(self) -> Space:
|
||||
return self.action_interpreter.action_space
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
def observation_space(self) -> Space:
|
||||
return self.state_interpreter.observation_space
|
||||
|
||||
def reset(self, **kwargs: Any) -> ObsType:
|
||||
|
@ -191,7 +193,7 @@ class EnvWrapper(
|
|||
self.seed_iterator = None
|
||||
return generate_nan_observation(self.observation_space)
|
||||
|
||||
def step(self, policy_action: PolicyActType, **kwargs: Any) -> tuple[ObsType, float, bool, InfoDict]:
|
||||
def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]:
|
||||
"""Environment step.
|
||||
|
||||
See the code along with comments to get a sequence of things happening here.
|
||||
|
@ -245,5 +247,5 @@ class EnvWrapper(
|
|||
info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
|
||||
return obs, rew, done, info_dict
|
||||
|
||||
def render(self):
|
||||
def render(self, mode: str = "human") -> None:
|
||||
raise NotImplementedError("Render is not implemented in EnvWrapper.")
|
||||
|
|
|
@ -11,11 +11,10 @@ from __future__ import annotations
|
|||
import copy
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import Any, Set, Callable, Type
|
||||
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv
|
||||
|
||||
from qlib.typehint import Literal
|
||||
|
@ -32,11 +31,11 @@ __all__ = [
|
|||
"vectorize_env",
|
||||
]
|
||||
|
||||
|
||||
FiniteEnvType = Literal["dummy", "subproc", "shmem"]
|
||||
T = Union[dict, list, tuple, np.ndarray]
|
||||
|
||||
|
||||
def fill_invalid(obj):
|
||||
def fill_invalid(obj: int | float | bool | T) -> T:
|
||||
if isinstance(obj, (int, float, bool)):
|
||||
return fill_invalid(np.array(obj))
|
||||
if hasattr(obj, "dtype"):
|
||||
|
@ -55,11 +54,11 @@ def fill_invalid(obj):
|
|||
raise ValueError(f"Unsupported value to fill with invalid: {obj}")
|
||||
|
||||
|
||||
def is_invalid(arr):
|
||||
if hasattr(arr, "dtype"):
|
||||
def is_invalid(arr: int | float | bool | T) -> bool:
|
||||
if isinstance(arr, np.ndarray):
|
||||
if np.issubdtype(arr.dtype, np.floating):
|
||||
return np.isnan(arr).all()
|
||||
return (np.iinfo(arr.dtype).max == arr).all()
|
||||
return cast(bool, cast(np.ndarray, np.iinfo(arr.dtype).max == arr).all())
|
||||
if isinstance(arr, dict):
|
||||
return all(is_invalid(o) for o in arr.values())
|
||||
if isinstance(arr, (list, tuple)):
|
||||
|
@ -140,44 +139,44 @@ class FiniteVectorEnv(BaseVectorEnv):
|
|||
|
||||
self._collector_guarded: bool = False
|
||||
|
||||
def _reset_alive_envs(self):
|
||||
def _reset_alive_envs(self) -> None:
|
||||
if not self._alive_env_ids:
|
||||
# starting or running out
|
||||
self._alive_env_ids = set(range(self.env_num))
|
||||
|
||||
# to workaround with tianshou's buffer and batch
|
||||
def _set_default_obs(self, obs):
|
||||
def _set_default_obs(self, obs: Any) -> None:
|
||||
if obs is not None and self._default_obs is None:
|
||||
self._default_obs = copy.deepcopy(obs)
|
||||
|
||||
def _set_default_info(self, info):
|
||||
def _set_default_info(self, info: Any) -> None:
|
||||
if info is not None and self._default_info is None:
|
||||
self._default_info = copy.deepcopy(info)
|
||||
|
||||
def _set_default_rew(self, rew):
|
||||
def _set_default_rew(self, rew: Any) -> None:
|
||||
if rew is not None and self._default_rew is None:
|
||||
self._default_rew = copy.deepcopy(rew)
|
||||
|
||||
def _get_default_obs(self):
|
||||
def _get_default_obs(self) -> Any:
|
||||
return copy.deepcopy(self._default_obs)
|
||||
|
||||
def _get_default_info(self):
|
||||
def _get_default_info(self) -> Any:
|
||||
return copy.deepcopy(self._default_info)
|
||||
|
||||
def _get_default_rew(self):
|
||||
def _get_default_rew(self) -> Any:
|
||||
return copy.deepcopy(self._default_rew)
|
||||
|
||||
# END
|
||||
|
||||
@staticmethod
|
||||
def _postproc_env_obs(obs):
|
||||
def _postproc_env_obs(obs: Any) -> Optional[Any]:
|
||||
# reserved for shmem vector env to restore empty observation
|
||||
if obs is None or check_nan_observation(obs):
|
||||
return None
|
||||
return obs
|
||||
|
||||
@contextmanager
|
||||
def collector_guard(self):
|
||||
def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]:
|
||||
"""Guard the collector. Recommended to guard every collect.
|
||||
|
||||
This guard is for two purposes.
|
||||
|
@ -207,7 +206,10 @@ class FiniteVectorEnv(BaseVectorEnv):
|
|||
for logger in self._logger:
|
||||
logger.on_env_all_done()
|
||||
|
||||
def reset(self, id=None):
|
||||
def reset(
|
||||
self,
|
||||
id: int | List[int] | np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
assert not self._zombie
|
||||
|
||||
# Check whether it's guarded by collector_guard()
|
||||
|
@ -219,23 +221,23 @@ class FiniteVectorEnv(BaseVectorEnv):
|
|||
RuntimeWarning,
|
||||
)
|
||||
|
||||
id = self._wrap_id(id)
|
||||
wrapped_id = self._wrap_id(id)
|
||||
self._reset_alive_envs()
|
||||
|
||||
# ask super to reset alive envs and remap to current index
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
|
||||
obs = [None] * len(id)
|
||||
id2idx = {i: k for k, i in enumerate(id)}
|
||||
request_id = [i for i in wrapped_id if i in self._alive_env_ids]
|
||||
obs = [None] * len(wrapped_id)
|
||||
id2idx = {i: k for k, i in enumerate(wrapped_id)}
|
||||
if request_id:
|
||||
for i, o in zip(request_id, super().reset(request_id)):
|
||||
obs[id2idx[i]] = self._postproc_env_obs(o)
|
||||
|
||||
for i, o in zip(id, obs):
|
||||
for i, o in zip(wrapped_id, obs):
|
||||
if o is None and i in self._alive_env_ids:
|
||||
self._alive_env_ids.remove(i)
|
||||
|
||||
# logging
|
||||
for i, o in zip(id, obs):
|
||||
for i, o in zip(wrapped_id, obs):
|
||||
if i in self._alive_env_ids:
|
||||
for logger in self._logger:
|
||||
logger.on_env_reset(i, obs)
|
||||
|
@ -248,19 +250,23 @@ class FiniteVectorEnv(BaseVectorEnv):
|
|||
obs[i] = self._get_default_obs()
|
||||
|
||||
if not self._alive_env_ids:
|
||||
# comment this line so that the env becomes indisposable
|
||||
# comment this line so that the env becomes indispensable
|
||||
# self.reset()
|
||||
self._zombie = True
|
||||
raise StopIteration
|
||||
|
||||
return np.stack(obs)
|
||||
|
||||
def step(self, action, id=None):
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: int | List[int] | np.ndarray | None = None,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert not self._zombie
|
||||
id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
|
||||
result = [[None, None, False, None] for _ in range(len(id))]
|
||||
wrapped_id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(wrapped_id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
|
||||
result = [[None, None, False, None] for _ in range(len(wrapped_id))]
|
||||
|
||||
# ask super to step alive envs and remap to current index
|
||||
if request_id:
|
||||
|
@ -270,7 +276,7 @@ class FiniteVectorEnv(BaseVectorEnv):
|
|||
result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])
|
||||
|
||||
# logging
|
||||
for i, r in zip(id, result):
|
||||
for i, r in zip(wrapped_id, result):
|
||||
if i in self._alive_env_ids:
|
||||
for logger in self._logger:
|
||||
logger.on_env_step(i, *r)
|
||||
|
@ -287,7 +293,8 @@ class FiniteVectorEnv(BaseVectorEnv):
|
|||
if r[3] is None:
|
||||
result[i][3] = self._get_default_info()
|
||||
|
||||
return list(map(np.stack, zip(*result)))
|
||||
ret = list(map(np.stack, zip(*result)))
|
||||
return cast(Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], ret)
|
||||
|
||||
|
||||
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
|
||||
|
@ -306,7 +313,7 @@ def vectorize_env(
|
|||
env_factory: Callable[..., gym.Env],
|
||||
env_type: FiniteEnvType,
|
||||
concurrency: int,
|
||||
logger: LogWriter | list[LogWriter],
|
||||
logger: LogWriter | List[LogWriter],
|
||||
) -> FiniteVectorEnv:
|
||||
"""Helper function to create a vector env. Can be used to replace usual VectorEnv.
|
||||
|
||||
|
@ -350,7 +357,7 @@ def vectorize_env(
|
|||
def env_factory(): ...
|
||||
vectorize_env(env_factory, ...)
|
||||
"""
|
||||
env_type_cls_mapping: dict[str, Type[FiniteVectorEnv]] = {
|
||||
env_type_cls_mapping: Dict[str, Type[FiniteVectorEnv]] = {
|
||||
"dummy": FiniteDummyVectorEnv,
|
||||
"subproc": FiniteSubprocVectorEnv,
|
||||
"shmem": FiniteShmemVectorEnv,
|
||||
|
|
|
@ -21,7 +21,7 @@ import logging
|
|||
from collections import defaultdict
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, Generic, Set, TYPE_CHECKING, Sequence, Callable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, List, Sequence, Set, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -65,13 +65,13 @@ class LogCollector:
|
|||
``min_loglevel`` is for optimization purposes: to avoid too much traffic on networks / in pipe.
|
||||
"""
|
||||
|
||||
_logged: dict[str, tuple[int, Any]]
|
||||
_logged: Dict[str, Tuple[int, Any]]
|
||||
_min_loglevel: int
|
||||
|
||||
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, min_loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
self._min_loglevel = int(min_loglevel)
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""Clear all collected contents."""
|
||||
self._logged = {}
|
||||
|
||||
|
@ -104,7 +104,10 @@ class LogCollector:
|
|||
self._add_metric(name, scalar, loglevel)
|
||||
|
||||
def add_array(
|
||||
self, name: str, array: np.ndarray | pd.DataFrame | pd.Series, loglevel: int | LogLevel = LogLevel.PERIODIC
|
||||
self,
|
||||
name: str,
|
||||
array: np.ndarray | pd.DataFrame | pd.Series,
|
||||
loglevel: int | LogLevel = LogLevel.PERIODIC,
|
||||
) -> None:
|
||||
"""Add an array with name into logging."""
|
||||
if loglevel < self._min_loglevel:
|
||||
|
@ -127,7 +130,7 @@ class LogCollector:
|
|||
|
||||
self._add_metric(name, obj, loglevel)
|
||||
|
||||
def logs(self) -> dict[str, np.ndarray]:
|
||||
def logs(self) -> Dict[str, np.ndarray]:
|
||||
return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()}
|
||||
|
||||
|
||||
|
@ -154,16 +157,16 @@ class LogWriter(Generic[ObsType, ActType]):
|
|||
active_env_ids: Set[int]
|
||||
"""Active environment ids in vector env."""
|
||||
|
||||
episode_lengths: dict[int, int]
|
||||
episode_lengths: Dict[int, int]
|
||||
"""Map from environment id to episode length."""
|
||||
|
||||
episode_rewards: dict[int, list[float]]
|
||||
episode_rewards: Dict[int, List[float]]
|
||||
"""Map from environment id to episode total reward."""
|
||||
|
||||
episode_logs: dict[int, list]
|
||||
episode_logs: Dict[int, list]
|
||||
"""Map from environment id to episode logs."""
|
||||
|
||||
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
self.loglevel = loglevel
|
||||
|
||||
self.global_step = 0
|
||||
|
@ -207,11 +210,12 @@ class LogWriter(Generic[ObsType, ActType]):
|
|||
# These are runtime infos.
|
||||
# Though they are loaded, I don't think it really helps.
|
||||
self.active_env_ids = state_dict["active_env_ids"]
|
||||
self.episode_lenghts = state_dict["episode_lengths"]
|
||||
self.episode_lengths = state_dict["episode_lengths"]
|
||||
self.episode_rewards = state_dict["episode_rewards"]
|
||||
self.episode_logs = state_dict["episode_logs"]
|
||||
|
||||
def aggregation(self, array: Sequence[Any], name: str | None = None) -> Any:
|
||||
@staticmethod
|
||||
def aggregation(array: Sequence[Any], name: str | None = None) -> Any:
|
||||
"""Aggregation function from step-wise to episode-wise.
|
||||
|
||||
If it's a sequence of float, take the mean.
|
||||
|
@ -229,7 +233,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
|||
else:
|
||||
return array[0]
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
"""This is triggered at the end of each trajectory.
|
||||
|
||||
Parameters
|
||||
|
@ -242,7 +246,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
|||
Logged contents for every steps.
|
||||
"""
|
||||
|
||||
def log_step(self, reward: float, contents: dict[str, Any]) -> None:
|
||||
def log_step(self, reward: float, contents: Dict[str, Any]) -> None:
|
||||
"""This is triggered at each step.
|
||||
|
||||
Parameters
|
||||
|
@ -265,7 +269,7 @@ class LogWriter(Generic[ObsType, ActType]):
|
|||
# TODO: reward can be a list of list for MARL
|
||||
self.episode_rewards[env_id].append(rew)
|
||||
|
||||
values: dict[str, Any] = {}
|
||||
values: Dict[str, Any] = {}
|
||||
|
||||
for key, (loglevel, value) in info["log"].items():
|
||||
if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME)
|
||||
|
@ -393,11 +397,11 @@ class ConsoleWriter(LogWriter):
|
|||
def __init__(
|
||||
self,
|
||||
log_every_n_episode: int = 20,
|
||||
total_episodes: int | None = None,
|
||||
total_episodes: int = None,
|
||||
float_format: str = ":.4f",
|
||||
counter_format: str = ":4d",
|
||||
loglevel: int | LogLevel = LogLevel.PERIODIC,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__(loglevel)
|
||||
# TODO: support log_every_n_step
|
||||
self.log_every_n_episode = log_every_n_episode
|
||||
|
@ -412,15 +416,15 @@ class ConsoleWriter(LogWriter):
|
|||
|
||||
# FIXME: save & reload
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
# Clear average meters
|
||||
self.metric_counts: dict[str, int] = defaultdict(int)
|
||||
self.metric_sums: dict[str, float] = defaultdict(float)
|
||||
self.metric_counts: Dict[str, int] = defaultdict(int)
|
||||
self.metric_sums: Dict[str, float] = defaultdict(float)
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
# Aggregate step-wise to episode-wise
|
||||
episode_wise_contents: dict[str, list] = defaultdict(list)
|
||||
episode_wise_contents: Dict[str, list] = defaultdict(list)
|
||||
|
||||
for step_contents in contents:
|
||||
for name, value in step_contents.items():
|
||||
|
@ -429,7 +433,7 @@ class ConsoleWriter(LogWriter):
|
|||
|
||||
# Generate log contents and track them in average-meter.
|
||||
# This should be done at every step, regardless of periodic or not.
|
||||
logs: dict[str, float] = {}
|
||||
logs: Dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values, name) # type: ignore
|
||||
|
||||
|
@ -441,7 +445,7 @@ class ConsoleWriter(LogWriter):
|
|||
# Only log periodically or at the end
|
||||
self.console_logger.info(self.generate_log_message(logs))
|
||||
|
||||
def generate_log_message(self, logs: dict[str, float]) -> str:
|
||||
def generate_log_message(self, logs: Dict[str, float]) -> str:
|
||||
if self.prefix:
|
||||
msg_prefix = self.prefix + " "
|
||||
else:
|
||||
|
@ -471,29 +475,29 @@ class CsvWriter(LogWriter):
|
|||
|
||||
SUPPORTED_TYPES = (float, str, pd.Timestamp)
|
||||
|
||||
all_records: list[dict[str, Any]]
|
||||
all_records: List[Dict[str, Any]]
|
||||
|
||||
# FIXME: save & reload
|
||||
|
||||
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC):
|
||||
def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None:
|
||||
super().__init__(loglevel)
|
||||
self.output_dir = output_dir
|
||||
self.output_dir.mkdir(exist_ok=True)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.all_records = []
|
||||
|
||||
def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None:
|
||||
def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None:
|
||||
# FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup
|
||||
episode_wise_contents: dict[str, list] = defaultdict(list)
|
||||
episode_wise_contents: Dict[str, list] = defaultdict(list)
|
||||
|
||||
for step_contents in contents:
|
||||
for name, value in step_contents.items():
|
||||
if isinstance(value, self.SUPPORTED_TYPES):
|
||||
episode_wise_contents[name].append(value)
|
||||
|
||||
logs: dict[str, float] = {}
|
||||
logs: Dict[str, float] = {}
|
||||
for name, values in episode_wise_contents.items():
|
||||
logs[name] = self.aggregation(values, name) # type: ignore
|
||||
|
||||
|
|
|
@ -2,14 +2,14 @@
|
|||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Generator, Optional
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Generator, Optional, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple
|
||||
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
|
@ -207,8 +207,18 @@ class BaseStrategy:
|
|||
range_limit = self.outer_trade_decision.get_data_cal_range_limit(rtype=rtype)
|
||||
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
|
||||
|
||||
def post_exe_step(self, execute_result: list) -> None:
|
||||
"""
|
||||
A hook for doing sth after the corresponding executor finished its execution.
|
||||
|
||||
class RLStrategy(BaseStrategy):
|
||||
Parameters
|
||||
----------
|
||||
execute_result :
|
||||
the execution result
|
||||
"""
|
||||
|
||||
|
||||
class RLStrategy(BaseStrategy, metaclass=ABCMeta):
|
||||
"""RL-based strategy"""
|
||||
|
||||
def __init__(
|
||||
|
@ -229,14 +239,14 @@ class RLStrategy(BaseStrategy):
|
|||
self.policy = policy
|
||||
|
||||
|
||||
class RLIntStrategy(RLStrategy):
|
||||
class RLIntStrategy(RLStrategy, metaclass=ABCMeta):
|
||||
"""(RL)-based (Strategy) with (Int)erpreter"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
state_interpreter: dict | StateInterpreter,
|
||||
action_interpreter: dict | ActionInterpreter,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
|
|
|
@ -271,7 +271,7 @@ class LocIndexer:
|
|||
if isinstance(_indexing, IndexData):
|
||||
_indexing = _indexing.data
|
||||
assert _indexing.ndim == 1
|
||||
if _indexing.dtype != np.bool:
|
||||
if _indexing.dtype != bool:
|
||||
_indexing = np.array(list(index.index(i) for i in _indexing))
|
||||
else:
|
||||
_indexing = index.index(_indexing)
|
||||
|
@ -431,7 +431,7 @@ class IndexData(metaclass=index_data_ops_creator):
|
|||
|
||||
# The code below could be simpler like methods in __getattribute__
|
||||
def __invert__(self):
|
||||
return self.__class__(~self.data.astype(np.bool), *self.indices)
|
||||
return self.__class__(~self.data.astype(bool), *self.indices)
|
||||
|
||||
def abs(self):
|
||||
"""get the abs of data except np.NaN."""
|
||||
|
|
|
@ -5,6 +5,8 @@ from random import randint, choice
|
|||
from pathlib import Path
|
||||
|
||||
import re
|
||||
from typing import Any, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -24,16 +26,16 @@ from qlib.rl.utils.finite_env import vectorize_env
|
|||
|
||||
|
||||
class SimpleEnv(gym.Env[int, int]):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.logger = LogCollector()
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
def reset(self, *args: Any, **kwargs: Any) -> int:
|
||||
self.step_count = 0
|
||||
return 0
|
||||
|
||||
def step(self, action: int):
|
||||
def step(self, action: int) -> Tuple[int, float, bool, dict]:
|
||||
self.logger.reset()
|
||||
|
||||
self.logger.add_scalar("reward", 42.0)
|
||||
|
@ -53,6 +55,9 @@ class SimpleEnv(gym.Env[int, int]):
|
|||
|
||||
return 1, 42.0, done, InfoDict(log=self.logger.logs(), aux_info={})
|
||||
|
||||
def render(self, mode: str = "human") -> None:
|
||||
pass
|
||||
|
||||
|
||||
class AnyPolicy(BasePolicy):
|
||||
def forward(self, batch, state=None):
|
||||
|
@ -86,7 +91,8 @@ def test_simple_env_logger(caplog):
|
|||
|
||||
|
||||
class SimpleSimulator(Simulator[int, float, float]):
|
||||
def __init__(self, initial: int, **kwargs) -> None:
|
||||
def __init__(self, initial: int, **kwargs: Any) -> None:
|
||||
super(SimpleSimulator, self).__init__(initial, **kwargs)
|
||||
self.initial = float(initial)
|
||||
|
||||
def step(self, action: float) -> None:
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from qlib.backtest.decision import Order, OrderDir
|
||||
from qlib.backtest.executor import NestedExecutor, SimulatorExecutor
|
||||
from qlib.backtest.utils import CommonInfrastructure
|
||||
from qlib.contrib.strategy import TWAPStrategy
|
||||
from qlib.rl.order_execution import CategoricalActionInterpreter
|
||||
from qlib.rl.order_execution.simulator_qlib import ExchangeConfig, SingleAssetOrderExecutionQlib
|
||||
|
||||
TOTAL_POSITION = 2100.0
|
||||
|
||||
python_version_request = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher")
|
||||
|
||||
|
||||
def is_close(a: float, b: float, epsilon: float = 1e-4) -> bool:
|
||||
return abs(a - b) <= epsilon
|
||||
|
||||
|
||||
def get_order() -> Order:
|
||||
return Order(
|
||||
stock_id="SH600000",
|
||||
amount=TOTAL_POSITION,
|
||||
direction=OrderDir.BUY,
|
||||
start_time=pd.Timestamp("2019-03-04 09:30:00"),
|
||||
end_time=pd.Timestamp("2019-03-04 14:29:00"),
|
||||
)
|
||||
|
||||
|
||||
def get_simulator(order: Order) -> SingleAssetOrderExecutionQlib:
|
||||
def _inner_executor_fn(time_per_step: str, common_infra: CommonInfrastructure) -> NestedExecutor:
|
||||
return NestedExecutor(
|
||||
time_per_step=time_per_step,
|
||||
inner_strategy=TWAPStrategy(),
|
||||
inner_executor=SimulatorExecutor(
|
||||
time_per_step="1min",
|
||||
verbose=False,
|
||||
trade_type=SimulatorExecutor.TT_SERIAL,
|
||||
generate_report=False,
|
||||
common_infra=common_infra,
|
||||
track_data=True,
|
||||
),
|
||||
common_infra=common_infra,
|
||||
track_data=True,
|
||||
)
|
||||
|
||||
DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "qlib_simulator"
|
||||
|
||||
# fmt: off
|
||||
qlib_config = {
|
||||
"provider_uri_day": DATA_ROOT_DIR / "qlib_1d",
|
||||
"provider_uri_1min": DATA_ROOT_DIR / "qlib_1min",
|
||||
"feature_root_dir": DATA_ROOT_DIR / "qlib_handler_stock",
|
||||
"feature_columns_today": [
|
||||
"$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
|
||||
"$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5",
|
||||
],
|
||||
"feature_columns_yesterday": [
|
||||
"$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
|
||||
"$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
exchange_config = ExchangeConfig(
|
||||
limit_threshold=("$ask == 0", "$bid == 0"),
|
||||
deal_price=("If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"),
|
||||
volume_threshold={
|
||||
"all": ("cum", "0.2 * DayCumsum($volume, '9:30', '14:29')"),
|
||||
"buy": ("current", "$askV1"),
|
||||
"sell": ("current", "$bidV1"),
|
||||
},
|
||||
open_cost=0.0005,
|
||||
close_cost=0.0015,
|
||||
min_cost=5.0,
|
||||
trade_unit=None,
|
||||
cash_limit=None,
|
||||
generate_report=False,
|
||||
)
|
||||
|
||||
return SingleAssetOrderExecutionQlib(
|
||||
order=order,
|
||||
time_per_step="30min",
|
||||
qlib_config=qlib_config,
|
||||
inner_executor_fn=_inner_executor_fn,
|
||||
exchange_config=exchange_config,
|
||||
)
|
||||
|
||||
|
||||
@python_version_request
|
||||
def test_simulator_first_step():
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp("2019-03-04 09:30:00")
|
||||
assert state.position == TOTAL_POSITION
|
||||
|
||||
AMOUNT = 300.0
|
||||
simulator.step(AMOUNT)
|
||||
state = simulator.get_state()
|
||||
assert state.cur_time == pd.Timestamp("2019-03-04 10:00:00")
|
||||
assert state.position == TOTAL_POSITION - AMOUNT
|
||||
assert len(state.history_exec) == 30
|
||||
assert state.history_exec.index[0] == pd.Timestamp("2019-03-04 09:30:00")
|
||||
|
||||
assert is_close(state.history_exec["market_volume"].iloc[0], 109382.382812)
|
||||
assert is_close(state.history_exec["market_price"].iloc[0], 149.566483)
|
||||
assert (state.history_exec["amount"] == AMOUNT / 30).all()
|
||||
assert (state.history_exec["deal_amount"] == AMOUNT / 30).all()
|
||||
assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483)
|
||||
assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825)
|
||||
assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30)
|
||||
# assert state.history_exec["ffr"].iloc[0] == 1 / 60 # FIXME
|
||||
|
||||
assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938)
|
||||
assert state.history_steps["amount"].iloc[0] == AMOUNT
|
||||
assert state.history_steps["deal_amount"].iloc[0] == AMOUNT
|
||||
assert state.history_steps["ffr"].iloc[0] == 1.0
|
||||
assert is_close(
|
||||
state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0),
|
||||
(state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000,
|
||||
)
|
||||
|
||||
|
||||
@python_version_request
|
||||
def test_simulator_stop_twap() -> None:
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
NUM_STEPS = 7
|
||||
for i in range(NUM_STEPS):
|
||||
simulator.step(TOTAL_POSITION / NUM_STEPS)
|
||||
|
||||
HISTORY_STEP_LENGTH = 30 * NUM_STEPS
|
||||
state = simulator.get_state()
|
||||
assert len(state.history_exec) == HISTORY_STEP_LENGTH
|
||||
|
||||
assert (state.history_exec["deal_amount"] == TOTAL_POSITION / HISTORY_STEP_LENGTH).all()
|
||||
assert is_close(state.history_steps["position"].iloc[0], TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS)
|
||||
assert is_close(state.history_steps["position"].iloc[-1], 0.0)
|
||||
assert is_close(state.position, 0.0)
|
||||
assert is_close(state.metrics["ffr"], 1.0)
|
||||
|
||||
assert is_close(state.metrics["market_price"], state.backtest_data.get_deal_price().mean())
|
||||
assert is_close(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
|
||||
assert is_close(state.metrics["trade_price"], state.metrics["market_price"])
|
||||
assert is_close(state.metrics["pa"], 0.0)
|
||||
|
||||
assert simulator.done()
|
||||
|
||||
|
||||
@python_version_request
|
||||
def test_interpreter() -> None:
|
||||
NUM_EXECUTION = 3
|
||||
order = get_order()
|
||||
simulator = get_simulator(order)
|
||||
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
|
||||
|
||||
NUM_STEPS = 7
|
||||
state = simulator.get_state()
|
||||
position_history = []
|
||||
for i in range(NUM_STEPS):
|
||||
simulator.step(interpreter_action(state, 1))
|
||||
state = simulator.get_state()
|
||||
position_history.append(state.position)
|
||||
|
||||
assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_simulator_first_step()
|
||||
test_simulator_stop_twap()
|
||||
test_interpreter()
|
|
@ -9,7 +9,6 @@ from typing import NamedTuple
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
from tianshou.data import Batch
|
||||
|
||||
|
@ -17,8 +16,8 @@ from qlib.backtest import Order
|
|||
from qlib.config import C
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.rl.data import pickle_styled
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.order_execution import *
|
||||
from qlib.rl.trainer import backtest, train
|
||||
from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
|
||||
|
||||
pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
|
||||
|
@ -38,7 +37,7 @@ CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
|
|||
|
||||
|
||||
def test_pickle_data_inspect():
|
||||
data = pickle_styled.load_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
|
||||
data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
|
||||
assert len(data) == 390
|
||||
|
||||
data = pickle_styled.load_intraday_processed_data(
|
||||
|
|
Загрузка…
Ссылка в новой задаче