-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
rolling_benchmark.py
41 lines (31 loc) · 1.23 KB
/
rolling_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from pathlib import Path
from typing import Union
import fire
from qlib import auto_init
from qlib.contrib.rolling.base import Rolling
from qlib.tests.data import GetData
DIRNAME = Path(__file__).absolute().resolve().parent
class RollingBenchmark(Rolling):
    """Rolling workflow preconfigured with the benchmark configs stored beside this script.

    ``CONF_LIST`` enumerates the configs mentioned in the README.md; the first
    entry doubles as the default configuration.
    """

    # The configs in the README.md
    CONF_LIST = [
        DIRNAME / "workflow_config_linear_Alpha158.yaml",
        DIRNAME / "workflow_config_lightgbm_Alpha158.yaml",
    ]
    DEFAULT_CONF = CONF_LIST[0]

    def __init__(self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwargs) -> None:
        """Initialise the rolling benchmark from a workflow config file.

        Args:
            conf_path: Path to the workflow config; defaults to ``DEFAULT_CONF``.
            horizon: Forwarded unchanged to ``Rolling.__init__``.
            **kwargs: Any further keyword arguments, forwarded to ``Rolling.__init__``.

        A warning is logged when ``conf_path`` does not refer to the same file
        as any entry of ``CONF_LIST``. ``Path.samefile`` touches the file
        system, so an ``OSError`` from it (e.g. a missing file) propagates.
        """
        # Kept for compatibility with the previous old code: accept plain strings too.
        conf_path = Path(conf_path)
        super().__init__(conf_path=conf_path, horizon=horizon, **kwargs)

        # any() stops at the first match, mirroring a loop that breaks on a hit.
        is_benchmark_conf = any(conf_path.samefile(known) for known in self.CONF_LIST)
        if not is_benchmark_conf:
            # NOTE(review): ``self.logger`` is presumably set up by ``Rolling`` - confirm.
            self.logger.warning("Model type is not in the benchmark!")
if __name__ == "__main__":
    # Script entry point: pick the data source, initialise qlib, then hand the
    # command line over to fire.
    provider_uri = os.environ.get("PROVIDER_URI", "")
    init_kwargs = {}
    if provider_uri:
        # An explicit PROVIDER_URI is passed straight through to auto_init.
        init_kwargs["provider_uri"] = provider_uri
    else:
        # No provider configured: fetch the data via GetData, skipping the
        # step when it already exists (exists_skip=True).
        GetData().qlib_data(exists_skip=True)
    auto_init(**init_kwargs)
    fire.Fire(RollingBenchmark)