Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor online serving rolling api #539

Merged
merged 1 commit into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config
1 change: 1 addition & 0 deletions examples/model_rolling/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
xgboost
4 changes: 3 additions & 1 deletion qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,9 +570,11 @@ def get_pre_trading_date(trading_date, future=False):


def transform_end_date(end_date=None, freq="day"):
"""get previous trading date
"""handle the end date with various format

If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
Otherwise, returns the end_date

----------
end_date: str
end trading date
Expand Down
24 changes: 9 additions & 15 deletions qlib/workflow/online/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from qlib.data.data import D
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.utils import transform_end_date
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, RecorderCollector
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(
task_template = [task_template]
self.task_template = task_template
self.rg = rolling_gen
assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen"
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()

Expand Down Expand Up @@ -174,28 +176,20 @@ def prepare_tasks(self, cur_time) -> List[dict]:
Returns:
List[dict]: a list of new tasks.
"""
# TODO: filtering recorders by the latest test segments is not necessary
latest_records, max_test = self._list_latest(self.tool.online_models())
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
calendar_latest = transform_end_date(cur_time)
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rec in latest_records:
task = rec.load_object("task")
old_tasks.append(deepcopy(task))
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks_tmp.append(task)
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return []
res = []
for rec in latest_records:
task = rec.load_object("task")
res.extend(self.rg.gen_following_tasks(task, calendar_latest))
return res

def _list_latest(self, rec_list: List[Recorder]):
"""
Expand Down
2 changes: 2 additions & 0 deletions qlib/workflow/online/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"
if to_date == None:
to_date = D.calendar(freq=freq)[-1]
self.to_date = pd.Timestamp(to_date)
# FIXME: this raises an error when running the routine with a delay trainer.
# Should we use another prediction updater for the delay trainer?
self.old_pred = record.load_object("pred.pkl")
self.last_end = self.old_pred.index.get_level_values("datetime").max()

Expand Down
104 changes: 66 additions & 38 deletions qlib/workflow/task/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import abc
import copy
import pandas as pd
from typing import List, Union, Callable

from qlib.utils import transform_end_date
Expand Down Expand Up @@ -139,6 +140,53 @@ def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Unio
self.test_key = "test"
self.train_key = "train"

def _update_task_segs(self, task, segs):
    """
    Install `segs` as the dataset segments of `task` (in place).

    A deep copy of `segs` is stored, so later changes to `segs` do not
    leak into the task. When `self.ds_extra_mod_func` is set, it is then
    called with `(task, self)` to apply extra modifications.
    """
    segs_copy = copy.deepcopy(segs)
    task["dataset"]["kwargs"]["segments"] = segs_copy

    extra_mod = self.ds_extra_mod_func
    if extra_mod is None:
        return
    extra_mod(task, self)

def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]:
    """
    Generate the rolling tasks that follow `task`, until `test_end`.

    This is a generator: tasks are produced lazily, one per rolling step.
    Every produced task is an independent deep copy of `task` with shifted
    segments, so callers may safely collect them into a list. `task` itself
    is never modified.

    Parameters
    ----------
    task : dict
        Qlib task format
    test_end : pd.Timestamp
        the latest rolling task includes `test_end`

    Returns
    -------
    List[dict]:
        the following tasks of `task` (`task` itself is excluded),
        yielded one at a time
    """
    prev_seg = task["dataset"]["kwargs"]["segments"]
    while True:
        segments = {}
        try:
            for k, seg in prev_seg.items():
                # decide how to shift
                # expanding only for train data, the segments size of test data and valid data won't change
                if k == self.train_key and self.rtype == self.ROLL_EX:
                    rtype = self.ta.SHIFT_EX
                else:
                    rtype = self.ta.SHIFT_SD
                # shift the segments data
                segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
            if segments[self.test_key][0] > test_end:
                break
        except KeyError:
            # We reach the end of tasks
            # No more rolling
            break

        prev_seg = segments
        # A fresh deep copy per step is required: yielding one shared dict
        # would make every collected task alias the same object, and all of
        # them would end up with the segments of the last rolling step.
        t = copy.deepcopy(task)
        self._update_task_segs(t, segments)
        yield t

def generate(self, task: dict) -> List[dict]:
"""
Converting the task into a rolling task.
Expand Down Expand Up @@ -191,43 +239,23 @@ def generate(self, task: dict) -> List[dict]:
"""
res = []

prev_seg = None
test_end = None
while True:
t = copy.deepcopy(task)

# calculate segments
if prev_seg is None:
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
else:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
t = copy.deepcopy(task)

# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(t, self)
res.append(t)
# calculate segments

# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))

# update segments of this task
self._update_task_segs(t, segments)

res.append(t)

# Update the following rolling
res.extend(self.gen_following_tasks(t, test_end))
return res