153 changes: 130 additions & 23 deletions model/kronos.py
@@ -1,3 +1,5 @@
import warnings
from dataclasses import dataclass
import numpy as np
import pandas as pd
import torch
@@ -10,6 +12,18 @@
from model.module import *


@dataclass
class KVCacheState:
logits: torch.Tensor | None = None
hidden_states: torch.Tensor | None = None
past_kv: list[LayerCache] | None = None

def reset(self) -> None:
self.logits = None
self.hidden_states = None
self.past_kv = None
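
For reference, a short sketch of the intended lifecycle of this state object (illustrative only; the prefill/step behaviour it supports is described by decode_s1 and _update_kv_cache_state further down in this diff):

state = KVCacheState()
assert state.past_kv is None    # nothing cached before the prefill pass
# After the prefill call to decode_s1(..., return_kv=True), the three fields hold
# the last-step logits, the normalized hidden states, and one LayerCache per layer.
# Once the generated context would exceed max_context, the state is cleared and
# every subsequent step recomputes the full window:
state.reset()
assert state.logits is None and state.hidden_states is None and state.past_kv is None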


class KronosTokenizer(nn.Module, PyTorchModelHubMixin):
"""
KronosTokenizer module for tokenizing input data using a hybrid quantization approach.
@@ -247,7 +261,7 @@ def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_for
s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None.

Returns:
Tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor]:
- s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
- s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size]
"""
@@ -275,9 +289,13 @@ def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_for
s2_logits = self.head.cond_forward(x2)
return s1_logits, s2_logits

def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):

def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None,
past_kv: list[LayerCache] | None = None,
max_context: int | None = None,
return_kv: bool = False):
"""
Decodes only the s1 tokens.
Decodes only the s1 tokens and optionally returns/updates KV cache state.

This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
and the context representation from the Transformer, which can be used for subsequent s2 decoding.
@@ -287,25 +305,45 @@ def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
past_kv (list[LayerCache], optional): Cached key/value tensors for each layer, used for single-step decoding.
max_context (int, optional): Maximum cache size used to trim stored kv tensors.
return_kv (bool, optional): When True, returns the per-layer caches built from this pass.

Returns:
Tuple[torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor, list[LayerCache] | None]:
- s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
- context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
- kv cache: Per-layer cache objects when requested, else None
"""
x = self.embedding([s1_ids, s2_ids])
if stamp is not None:
time_embedding = self.time_emb(stamp)
x = x + time_embedding
x = self.token_drop(x)

for layer in self.transformer:
x = layer(x, key_padding_mask=padding_mask)
kv_cache: list[LayerCache] | None = None
if past_kv is not None:
if len(past_kv) != self.n_layers:
raise ValueError(f"Expected {self.n_layers} past_kv entries, got {len(past_kv)}")
if s1_ids.size(1) != 1 or s2_ids.size(1) != 1:
raise ValueError("Step decoding expects seq_len == 1")
kv_cache = past_kv
elif return_kv:
kv_cache = [LayerCache() for _ in range(self.n_layers)]

for idx, layer in enumerate(self.transformer):
layer_cache = None if kv_cache is None else kv_cache[idx]
x = layer(
x,
key_padding_mask=padding_mask,
layer_cache=layer_cache,
max_cache_len=max_context,
)

x = self.norm(x)

s1_logits = self.head(x)
return s1_logits, x
cache_return = kv_cache if (return_kv or past_kv is not None) else None
return s1_logits, x, cache_return
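
For reference, a minimal sketch of the two call patterns this method now supports (variable names mirror those used in auto_regressive_inference and _update_kv_cache_state below; tensor contents are assumed):

# Prefill: encode the whole visible window once and request the per-layer caches.
s1_logits, context, past_kv = model.decode_s1(
    s1_window, s2_window, stamp=current_stamp,
    max_context=max_context, return_kv=True,
)
# Step decoding: feed exactly one new (s1, s2) token pair and reuse the caches;
# a ValueError is raised if seq_len != 1 when past_kv is supplied.
step_logits, step_hidden, past_kv = model.decode_s1(
    sample_pre, sample_post, stamp=stamp_step,
    past_kv=past_kv, max_context=max_context,
)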

def decode_s2(self, context, s1_ids, padding_mask=None):
"""
@@ -386,8 +424,51 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l
return x


def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
def _update_kv_cache_state(
model: Kronos,
cache_state: KVCacheState,
sample_pre: torch.Tensor,
sample_post: torch.Tensor,
full_stamp: torch.Tensor,
current_seq_len: int,
max_context: int,
) -> None:
"""
Update KV cache state after generating a new step.

TODO: Implement cache rolling instead of resetting.
When `max_context` is exceeded, the cache resets and generation falls
back to recomputing the entire context each step.
"""
if current_seq_len < max_context:
stamp_step = full_stamp[:, current_seq_len:current_seq_len + 1, :].contiguous()
step_logits, h_last_norm, new_past = model.decode_s1(
sample_pre,
sample_post,
stamp_step,
past_kv=cache_state.past_kv,
max_context=max_context,
)
cache_state.logits = step_logits[:, -1, :]
cache_state.hidden_states = torch.cat([cache_state.hidden_states, h_last_norm], dim=1)
if cache_state.hidden_states.size(1) > max_context:
cache_state.hidden_states = cache_state.hidden_states[:, -max_context:, :]
cache_state.past_kv = new_past
else:
if current_seq_len == max_context:  # warn only once, on the first step past the cache limit
warnings.warn(
f"Generation length exceeded max_context ({max_context}). "
"KV cache now resets every step, eliminating its speedup. "
"Consider increasing max_context.",
UserWarning,
)
cache_state.reset()
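
A worked example of the fallback described in the docstring above (the numbers are illustrative, not taken from the diff):

# The cache is appended to while current_seq_len < max_context, i.e. for the
# first max_context - initial_seq_len generated steps; every later step resets
# the cache and recomputes the full window.
initial_seq_len, max_context, pred_len = 400, 512, 200
cached_steps = max(0, max_context - initial_seq_len)    # 112 steps use the cache
recomputed_steps = max(0, pred_len - cached_steps)      # 88 steps fall back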


def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, use_kv_cache=False):
with torch.no_grad():
cache_state = KVCacheState() if use_kv_cache else None

x = torch.clip(x, -clip, clip)

device = x.device
@@ -396,7 +477,7 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)

x_token = tokenizer.encode(x, half=True)

initial_seq_len = x.size(1)
batch_size = x_token[0].size(0)
total_seq_len = initial_seq_len + pred_len
@@ -417,24 +498,39 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
ran = trange
else:
ran = range

for i in ran(pred_len):
current_seq_len = initial_seq_len + i
window_len = min(current_seq_len, max_context)

if current_seq_len <= max_context:
input_tokens = [
pre_buffer[:, :window_len],
post_buffer[:, :window_len]
]
s1_window = pre_buffer[:, :window_len]
s2_window = post_buffer[:, :window_len]
else:
input_tokens = [pre_buffer, post_buffer]
s1_window = pre_buffer
s2_window = post_buffer

context_end = current_seq_len
context_start = max(0, context_end - max_context)
current_stamp = full_stamp[:, context_start:context_end, :].contiguous()

s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
s1_logits = s1_logits[:, -1, :]
if use_kv_cache:
if cache_state.logits is None:
if window_len <= 0:
raise RuntimeError("KV cache decoding requires a non-empty context window.")
s1_logits_full, hidden_states, past_kv = model.decode_s1(
s1_window, s2_window, current_stamp, padding_mask=None, max_context=max_context, return_kv=True
)
cache_state.logits = s1_logits_full[:, -1, :]
cache_state.hidden_states = hidden_states
cache_state.past_kv = past_kv

s1_logits = cache_state.logits
context = cache_state.hidden_states
else:
s1_logits, context, _ = model.decode_s1(s1_window, s2_window, current_stamp)
s1_logits = s1_logits[:, -1, :]

sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)

s2_logits = model.decode_s2(context, sample_pre)
@@ -453,6 +549,17 @@ def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context
pre_buffer[:, -1] = sample_pre.squeeze(-1)
post_buffer[:, -1] = sample_post.squeeze(-1)

if use_kv_cache:
_update_kv_cache_state(
model=model,
cache_state=cache_state,
sample_pre=sample_pre,
sample_post=sample_post,
full_stamp=full_stamp,
current_seq_len=current_seq_len,
max_context=max_context,
)

full_pre = torch.cat([x_token[0], generated_pre], dim=1)
full_post = torch.cat([x_token[1], generated_post], dim=1)

@@ -495,18 +602,18 @@ def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
self.tokenizer = self.tokenizer.to(self.device)
self.model = self.model.to(self.device)

def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose):
def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, use_kv_cache=False):

x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device)
y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)

preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len,
self.clip, T, top_k, top_p, sample_count, verbose)
self.clip, T, top_k, top_p, sample_count, verbose, use_kv_cache=use_kv_cache)
preds = preds[:, -pred_len:, :]
return preds

def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, use_kv_cache=False):

if not isinstance(df, pd.DataFrame):
raise ValueError("Input must be a pandas DataFrame.")
@@ -540,7 +647,7 @@ def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=
x_stamp = x_stamp[np.newaxis, :]
y_stamp = y_stamp[np.newaxis, :]

preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose)
preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, use_kv_cache=use_kv_cache)

preds = preds.squeeze(0)
preds = preds * (x_std + 1e-5) + x_mean
Expand All @@ -549,7 +656,7 @@ def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=
return pred_df
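
A minimal call sketch for the new flag (hypothetical: `predictor` stands for an instance of the wrapper class whose __init__, generate and predict methods appear in this diff; its class name is not visible in this excerpt, and the input DataFrame and timestamps are assumed to be prepared as predict expects):

pred_df = predictor.predict(
    df, x_timestamp, y_timestamp, pred_len=120,
    T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True,
    use_kv_cache=True,   # new in this PR; defaults to False
)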


def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, use_kv_cache=False):
"""
Perform parallel (batch) prediction on multiple time series. All series must have the same historical length and prediction length (pred_len).

@@ -639,7 +746,7 @@ def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T
x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32) # (B, seq_len, time_feat)
y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32) # (B, pred_len, time_feat)

preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose)
preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose, use_kv_cache=use_kv_cache)
# preds: (B, pred_len, feat)

pred_dfs = []
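
A matching sketch for the batch path (same hypothetical `predictor`; per the docstring above, all series must share the same history length and pred_len, and the method presumably returns one DataFrame per input series):

pred_dfs = predictor.predict_batch(
    [df_a, df_b], [x_ts_a, x_ts_b], [y_ts_a, y_ts_b],
    pred_len=120, T=1.0, top_k=0, top_p=0.9, sample_count=1,
    verbose=True, use_kv_cache=True,
)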