зеркало из https://github.com/microsoft/qlib.git
Use mock data for element operator tests. (#1330)
This commit is contained in:
Родитель
08de1a1874
Коммит
fb5888be9e
|
@ -1,10 +1,16 @@
|
|||
from typing import Union, List, Dict, Tuple
|
||||
import unittest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import io
|
||||
|
||||
from .data import GetData
|
||||
from .. import init
|
||||
from ..constant import REG_CN
|
||||
from ..constant import REG_CN, REG_TW
|
||||
from qlib.data.filter import NameDFilter
|
||||
from qlib.data import D
|
||||
from qlib.data.data import Cal, DatasetD
|
||||
from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT
|
||||
|
||||
|
||||
class TestAutoData(unittest.TestCase):
|
||||
|
@ -75,3 +81,211 @@ class TestOperatorData(TestAutoData):
|
|||
cls.end_time = cal[-1]
|
||||
cls.inst = list(instruments_d.keys())[0]
|
||||
cls.spans = list(instruments_d.values())[0]
|
||||
|
||||
|
||||
MOCK_DATA = """
|
||||
id,symbol,datetime,interval,volume,open,high,low,close
|
||||
20275,0050,2022-01-03 00:00:00,day,6761.0,146.0,147.35,146.0,146.4
|
||||
20276,0050,2022-01-04 00:00:00,day,9608.0,147.7,149.6,147.7,149.6
|
||||
20277,0050,2022-01-05 00:00:00,day,11387.0,150.1,150.55,149.1,149.3
|
||||
20278,0050,2022-01-06 00:00:00,day,8611.0,148.3,148.75,147.0,147.9
|
||||
20279,0050,2022-01-07 00:00:00,day,6954.0,148.3,149.0,146.5,146.6
|
||||
20280,0050,2022-01-10 00:00:00,day,15684.0,146.0,147.8,145.4,147.55
|
||||
20281,0050,2022-01-11 00:00:00,day,17741.0,147.6,148.5,146.7,148.3
|
||||
20282,0050,2022-01-12 00:00:00,day,10134.0,149.35,149.6,148.7,149.55
|
||||
20283,0050,2022-01-13 00:00:00,day,7431.0,149.55,150.45,149.55,150.3
|
||||
20284,0050,2022-01-14 00:00:00,day,10091.0,150.8,151.2,149.05,150.3
|
||||
20285,0050,2022-01-17 00:00:00,day,6899.0,151.1,152.4,151.1,152.0
|
||||
20286,0050,2022-01-18 00:00:00,day,14360.0,152.2,152.25,150.15,150.3
|
||||
20287,0050,2022-01-19 00:00:00,day,14654.0,149.0,149.65,148.25,148.5
|
||||
20288,0050,2022-01-20 00:00:00,day,16201.0,148.5,149.2,147.6,149.1
|
||||
20289,0050,2022-01-21 00:00:00,day,29848.0,143.9,143.95,142.3,142.65
|
||||
20290,0050,2022-01-24 00:00:00,day,13143.0,142.1,144.0,141.7,144.0
|
||||
20291,0050,2022-01-25 00:00:00,day,23982.0,142.55,142.55,141.25,141.65
|
||||
20292,0050,2022-01-26 00:00:00,day,17729.0,141.15,142.2,141.05,141.55
|
||||
8547,1101,2021-12-01 00:00:00,day,16119.0,46.0,46.85,46.0,46.6
|
||||
8548,1101,2021-12-02 00:00:00,day,14521.0,46.6,46.7,46.3,46.3
|
||||
8549,1101,2021-12-03 00:00:00,day,14357.0,46.55,46.85,46.4,46.4
|
||||
8550,1101,2021-12-06 00:00:00,day,15115.0,46.45,47.35,46.4,47.3
|
||||
8551,1101,2021-12-07 00:00:00,day,13117.0,47.35,47.55,46.9,47.55
|
||||
8552,1101,2021-12-08 00:00:00,day,10329.0,47.75,47.8,47.5,47.7
|
||||
8553,1101,2021-12-09 00:00:00,day,9300.0,47.8,47.85,47.1,47.4
|
||||
8554,1101,2021-12-10 00:00:00,day,9919.0,47.4,47.6,47.1,47.3
|
||||
8555,1101,2021-12-13 00:00:00,day,7784.0,47.3,47.75,47.1,47.1
|
||||
8556,1101,2021-12-14 00:00:00,day,9373.0,47.05,47.2,46.95,47.0
|
||||
8557,1101,2021-12-15 00:00:00,day,11189.0,47.0,47.3,46.8,46.95
|
||||
8558,1101,2021-12-16 00:00:00,day,7516.0,47.0,47.15,46.8,46.9
|
||||
8559,1101,2021-12-17 00:00:00,day,18502.0,46.95,47.6,46.9,47.45
|
||||
8560,1101,2021-12-20 00:00:00,day,11309.0,47.45,47.5,47.1,47.4
|
||||
8561,1101,2021-12-21 00:00:00,day,5666.0,47.4,47.45,47.1,47.25
|
||||
8562,1101,2021-12-22 00:00:00,day,5460.0,47.4,47.45,47.2,47.4
|
||||
8563,1101,2021-12-23 00:00:00,day,9371.0,47.3,47.7,47.3,47.7
|
||||
8564,1101,2021-12-24 00:00:00,day,5980.0,47.75,47.95,47.75,47.9
|
||||
8565,1101,2021-12-27 00:00:00,day,5709.0,47.9,48.1,47.9,48.1
|
||||
8566,1101,2021-12-28 00:00:00,day,7777.0,48.1,48.15,47.95,48.15
|
||||
8567,1101,2021-12-29 00:00:00,day,5309.0,48.15,48.25,48.05,48.15
|
||||
8568,1101,2021-12-30 00:00:00,day,4616.0,48.15,48.2,48.0,48.0
|
||||
8569,1101,2022-01-03 00:00:00,day,12350.0,48.05,48.15,47.35,47.45
|
||||
8570,1101,2022-01-04 00:00:00,day,11439.0,47.5,47.6,47.0,47.3
|
||||
8571,1101,2022-01-05 00:00:00,day,9692.0,47.1,47.3,47.0,47.15
|
||||
8572,1101,2022-01-06 00:00:00,day,12361.0,47.3,47.6,47.15,47.6
|
||||
8573,1101,2022-01-07 00:00:00,day,10921.0,47.6,47.65,47.2,47.45
|
||||
8574,1101,2022-01-10 00:00:00,day,11925.0,47.45,47.5,47.0,47.3
|
||||
8575,1101,2022-01-11 00:00:00,day,11047.0,47.1,47.5,47.1,47.5
|
||||
8576,1101,2022-01-12 00:00:00,day,10817.0,47.5,47.5,47.1,47.5
|
||||
8577,1101,2022-01-13 00:00:00,day,13849.0,47.5,47.95,47.4,47.95
|
||||
8578,1101,2022-01-14 00:00:00,day,9460.0,47.85,47.85,47.45,47.6
|
||||
8579,1101,2022-01-17 00:00:00,day,9057.0,47.55,47.7,47.35,47.6
|
||||
8580,1101,2022-01-18 00:00:00,day,8089.0,47.6,47.75,47.45,47.75
|
||||
8581,1101,2022-01-19 00:00:00,day,5110.0,47.6,47.7,47.5,47.6
|
||||
8582,1101,2022-01-20 00:00:00,day,6327.0,47.55,47.7,47.45,47.5
|
||||
8583,1101,2022-01-21 00:00:00,day,9470.0,47.5,47.65,47.15,47.4
|
||||
8584,1101,2022-01-24 00:00:00,day,5475.0,47.1,47.3,47.0,47.15
|
||||
8585,1101,2022-01-25 00:00:00,day,16153.0,47.0,47.05,46.6,46.8
|
||||
8586,1101,2022-01-26 00:00:00,day,7772.0,46.7,47.0,46.55,46.85
|
||||
8587,1101,2022-02-07 00:00:00,day,17031.0,46.55,47.1,46.0,47.1
|
||||
8588,1101,2022-02-08 00:00:00,day,9741.0,47.1,47.25,46.9,46.95
|
||||
8589,1101,2022-02-09 00:00:00,day,7968.0,46.95,47.3,46.9,47.3
|
||||
8590,1101,2022-02-10 00:00:00,day,7479.0,47.15,47.55,47.05,47.55
|
||||
8591,1101,2022-02-11 00:00:00,day,6841.0,47.3,47.55,47.15,47.55
|
||||
8592,1101,2022-02-14 00:00:00,day,9136.0,47.2,47.3,46.95,47.15
|
||||
8593,1101,2022-02-15 00:00:00,day,5444.0,47.05,47.1,46.8,47.0
|
||||
8594,1101,2022-02-16 00:00:00,day,8751.0,47.0,47.15,47.0,47.0
|
||||
8595,1101,2022-02-17 00:00:00,day,10662.0,47.15,47.55,47.1,47.45
|
||||
8596,1101,2022-02-18 00:00:00,day,8781.0,47.25,47.55,47.2,47.45
|
||||
8597,1101,2022-02-21 00:00:00,day,8201.0,47.35,47.75,47.15,47.6
|
||||
8598,1101,2022-02-22 00:00:00,day,10655.0,47.4,47.7,47.1,47.7
|
||||
8599,1101,2022-02-23 00:00:00,day,8040.0,47.7,47.85,47.45,47.65
|
||||
8600,1101,2022-02-24 00:00:00,day,13124.0,47.5,47.5,47.1,47.3
|
||||
8601,1101,2022-02-25 00:00:00,day,14556.0,47.2,47.5,46.9,47.35
|
||||
"""
|
||||
|
||||
MOCK_DF = pd.read_csv(io.StringIO(MOCK_DATA), header=0, dtype={"symbol": str})
|
||||
|
||||
|
||||
class MockStorageBase:
|
||||
def __init__(self, **kwargs):
|
||||
self.df = MOCK_DF
|
||||
|
||||
|
||||
class MockCalendarStorage(MockStorageBase, CalendarStorage):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self._data = sorted(self.df["datetime"].unique())
|
||||
|
||||
@property
|
||||
def data(self) -> List[CalVT]:
|
||||
return self._data
|
||||
|
||||
def __getitem__(self, i: Union[int, slice]) -> Union[CalVT, List[CalVT]]:
|
||||
return self.data[i]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class MockInstrumentStorage(MockStorageBase, InstrumentStorage):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
instruments = {}
|
||||
for symbol, group in self.df.groupby(by="symbol"):
|
||||
start = group["datetime"].iloc[0]
|
||||
end = group["datetime"].iloc[-1]
|
||||
instruments[symbol] = [(start, end)]
|
||||
self._data = instruments
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[InstKT, InstVT]:
|
||||
return self._data
|
||||
|
||||
def __getitem__(self, k: InstKT) -> InstVT:
|
||||
return self.data[k]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class MockFeatureStorage(MockStorageBase, FeatureStorage):
|
||||
def __init__(self, instrument: str, field: str, freq: str, db_region: str = None, **kwargs): # type: ignore
|
||||
super().__init__(instrument=instrument, field=field, freq=freq, db_region=db_region, **kwargs)
|
||||
self.field = field
|
||||
calendar = sorted(self.df["datetime"].unique())
|
||||
df_calendar = pd.DataFrame(calendar, columns=["datetime"]).set_index("datetime")
|
||||
df = self.df[self.df["symbol"] == instrument]
|
||||
data_dt_field = "datetime"
|
||||
cal_df = df_calendar[
|
||||
(df_calendar.index >= df[data_dt_field].min()) & (df_calendar.index <= df[data_dt_field].max())
|
||||
]
|
||||
df = df.set_index(data_dt_field)
|
||||
df_data = df.reindex(cal_df.index)
|
||||
date_index = df_calendar.index.get_loc(df_data.index.min()) # type: ignore
|
||||
df_data.reset_index(inplace=True)
|
||||
df_data.index += date_index
|
||||
self._data = df_data
|
||||
|
||||
@property
|
||||
def data(self) -> pd.Series:
|
||||
return self._data[self.field]
|
||||
|
||||
@property
|
||||
def start_index(self) -> Union[int, None]:
|
||||
if self._data.empty:
|
||||
return None
|
||||
return self._data.index[0]
|
||||
|
||||
@property
|
||||
def end_index(self) -> Union[int, None]:
|
||||
if self._data.empty:
|
||||
return None
|
||||
# The next data appending index point will be `end_index + 1`
|
||||
return self._data.index[-1]
|
||||
|
||||
def __getitem__(self, i: Union[int, slice]) -> Union[Tuple[int, float], pd.Series]:
|
||||
df = self._data
|
||||
storage_start_index = df.index[0]
|
||||
storage_end_index = df.index[-1]
|
||||
if isinstance(i, int):
|
||||
if storage_start_index > i or i > storage_end_index:
|
||||
raise IndexError(f"{i}: start index is {storage_start_index}")
|
||||
data = self.data[i]
|
||||
return i, data
|
||||
elif isinstance(i, slice):
|
||||
start_index = storage_start_index if i.start is None else i.start
|
||||
end_index = storage_end_index if i.stop is None else i.stop
|
||||
si = max(start_index, storage_start_index)
|
||||
if si > end_index or self.field not in df.columns:
|
||||
return pd.Series(dtype=np.float32) # type: ignore
|
||||
data = df[self.field].tolist()
|
||||
result = data[si - storage_start_index : end_index - storage_start_index]
|
||||
return pd.Series(result, index=pd.RangeIndex(si, si + len(result))) # type: ignore
|
||||
else:
|
||||
raise TypeError(f"type(i) = {type(i)}")
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class TestMockData(unittest.TestCase):
|
||||
_setup_kwargs = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {"backend": {"class": "MockCalendarStorage", "module_path": "qlib.tests"}},
|
||||
},
|
||||
"instrument_provider": {
|
||||
"class": "LocalInstrumentProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {"backend": {"class": "MockInstrumentStorage", "module_path": "qlib.tests"}},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {"backend": {"class": "MockFeatureStorage", "module_path": "qlib.tests"}},
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
|
||||
provider_uri = "Not necessary."
|
||||
init(region=REG_TW, provider_uri=provider_uri, expression_cache=None, dataset_cache=None, **cls._setup_kwargs)
|
||||
|
|
|
@ -1,17 +1,52 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from qlib.data import DatasetProvider
|
||||
from qlib.tests import TestOperatorData
|
||||
from qlib.data.data import ExpressionD
|
||||
from qlib.tests import TestOperatorData, TestMockData, MOCK_DF
|
||||
from qlib.config import C
|
||||
|
||||
|
||||
class TestElementOperator(TestMockData):
|
||||
def setUp(self) -> None:
|
||||
self.instrument = "0050"
|
||||
self.start_time = "2022-01-01"
|
||||
self.end_time = "2022-02-01"
|
||||
self.freq = "day"
|
||||
self.mock_df = MOCK_DF[MOCK_DF["symbol"] == self.instrument]
|
||||
|
||||
def test_Abs(self):
|
||||
field = "Abs($close-Ref($close, 1))"
|
||||
result = ExpressionD.expression(self.instrument, field, self.start_time, self.end_time, self.freq)
|
||||
self.assertGreaterEqual(result.min(), 0)
|
||||
result = result.to_numpy()
|
||||
prev_close = self.mock_df["close"].shift(1)
|
||||
close = self.mock_df["close"]
|
||||
change = prev_close - close
|
||||
golden = change.abs().to_numpy()
|
||||
self.assertIsNone(np.testing.assert_allclose(result, golden))
|
||||
|
||||
def test_Sign(self):
|
||||
field = "Sign($close-Ref($close, 1))"
|
||||
result = ExpressionD.expression(self.instrument, field, self.start_time, self.end_time, self.freq)
|
||||
result = result.to_numpy()
|
||||
prev_close = self.mock_df["close"].shift(1)
|
||||
close = self.mock_df["close"]
|
||||
change = close - prev_close
|
||||
change[change > 0] = 1.0
|
||||
change[change < 0] = -1.0
|
||||
golden = change.to_numpy()
|
||||
self.assertIsNone(np.testing.assert_allclose(result, golden))
|
||||
|
||||
|
||||
class TestOperatorDataSetting(TestOperatorData):
|
||||
def test_setting(self):
|
||||
self.assertEqual(len(self.instruments_d), 1)
|
||||
self.assertGreater(len(self.cal), 0)
|
||||
|
||||
|
||||
class TestElementOperator(TestOperatorData):
|
||||
class TestInstElementOperator(TestOperatorData):
|
||||
def setUp(self) -> None:
|
||||
freq = "day"
|
||||
expressions = [
|
||||
|
@ -24,6 +59,7 @@ class TestElementOperator(TestOperatorData):
|
|||
)
|
||||
self.data.columns = columns
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_abs(self):
|
||||
abs_values = self.data["abs"]
|
||||
self.assertGreater(abs_values[2], 0)
|
||||
|
|
Загрузка…
Ссылка в новой задаче