diff --git a/model/kronos.py b/model/kronos.py
index e7ebba60..7bee5321 100644
--- a/model/kronos.py
+++ b/model/kronos.py
@@ -1,3 +1,5 @@
+import warnings
+from dataclasses import dataclass
 import numpy as np
 import pandas as pd
 import torch
@@ -10,6 +12,18 @@
 from model.module import *


+@dataclass
+class KVCacheState:
+    logits: torch.Tensor | None = None
+    hidden_states: torch.Tensor | None = None
+    past_kv: list[LayerCache] | None = None
+
+    def reset(self) -> None:
+        self.logits = None
+        self.hidden_states = None
+        self.past_kv = None
+
+
 class KronosTokenizer(nn.Module, PyTorchModelHubMixin):
     """
     KronosTokenizer module for tokenizing input data using a hybrid quantization approach.
@@ -247,7 +261,7 @@ def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_for
             s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None.

         Returns:
-            Tuple[torch.Tensor, torch.Tensor]:
+            tuple[torch.Tensor, torch.Tensor]:
                 - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
                 - s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size]
         """
@@ -275,9 +289,13 @@ def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_for
         s2_logits = self.head.cond_forward(x2)
         return s1_logits, s2_logits

-    def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
+
+    def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None,
+                  past_kv: list[LayerCache] | None = None,
+                  max_context: int | None = None,
+                  return_kv: bool = False):
         """
-        Decodes only the s1 tokens.
+        Decodes only the s1 tokens and optionally returns/updates KV cache state.

         This method performs a forward pass to predict only s1 tokens. It returns the s1 logits and the context
         representation from the Transformer, which can be used for subsequent s2 decoding.
@@ -287,11 +305,15 @@ def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
             s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
             stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
             padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
+            past_kv (list[LayerCache], optional): Cached key/value tensors for each layer (step decoding).
+            max_context (int, optional): Maximum cache length used to trim stored key/value tensors.
+            return_kv (bool, optional): When True, returns the per-layer caches built from this pass.

         Returns:
-            Tuple[torch.Tensor, torch.Tensor]:
+            tuple[torch.Tensor, torch.Tensor, list[LayerCache] | None]:
                 - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
                 - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
+                - kv cache: Per-layer cache objects when requested, else None
         """
         x = self.embedding([s1_ids, s2_ids])
         if stamp is not None:
@@ -299,13 +321,30 @@ def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
             time_embedding = self.time_emb(stamp)
             x = x + time_embedding
         x = self.token_drop(x)

-        for layer in self.transformer:
-            x = layer(x, key_padding_mask=padding_mask)
+        kv_cache: list[LayerCache] | None = None
+        if past_kv is not None:
+            if len(past_kv) != self.n_layers:
+                raise ValueError(f"Expected {self.n_layers} past_kv entries, got {len(past_kv)}")
+            if s1_ids.size(1) != 1 or s2_ids.size(1) != 1:
+                raise ValueError("Step decoding expects seq_len == 1")
+            kv_cache = past_kv
+        elif return_kv:
+            kv_cache = [LayerCache() for _ in range(self.n_layers)]
+
+        for idx, layer in enumerate(self.transformer):
+            layer_cache = None if kv_cache is None else kv_cache[idx]
+            x = layer(
+                x,
+                key_padding_mask=padding_mask,
+                layer_cache=layer_cache,
+                max_cache_len=max_context,
+            )

         x = self.norm(x)
         s1_logits = self.head(x)
-        return s1_logits, x
+        cache_return = kv_cache if (return_kv or past_kv is not None) else None
+        return s1_logits, x, cache_return

     def decode_s2(self, context, s1_ids, padding_mask=None):
         """
@@ -386,8 +425,51 @@ def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_l
     return x


-def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False):
+def _update_kv_cache_state(
+    model: Kronos,
+    cache_state: KVCacheState,
+    sample_pre: torch.Tensor,
+    sample_post: torch.Tensor,
+    full_stamp: torch.Tensor,
+    current_seq_len: int,
+    max_context: int,
+) -> None:
+    """
+    Update KV cache state after generating a new step.
+
+    TODO: Implement cache rolling instead of resetting.
+    When `max_context` is exceeded, the cache resets and generation falls
+    back to recomputing the entire context each step.
+    """
+    if current_seq_len < max_context:
+        stamp_step = full_stamp[:, current_seq_len:current_seq_len + 1, :].contiguous()
+        step_logits, h_last_norm, new_past = model.decode_s1(
+            sample_pre,
+            sample_post,
+            stamp_step,
+            past_kv=cache_state.past_kv,
+            max_context=max_context,
+        )
+        cache_state.logits = step_logits[:, -1, :]
+        cache_state.hidden_states = torch.cat([cache_state.hidden_states, h_last_norm], dim=1)
+        if cache_state.hidden_states.size(1) > max_context:
+            cache_state.hidden_states = cache_state.hidden_states[:, -max_context:, :]
+        cache_state.past_kv = new_past
+    else:
+        if current_seq_len == max_context:
+            warnings.warn(
+                f"Generation length exceeded max_context ({max_context}). "
+                "KV cache now resets every step, eliminating its speedup. "
+                "Consider increasing max_context.",
+                UserWarning,
+            )
+        cache_state.reset()
+
+
+def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, use_kv_cache=False):
     with torch.no_grad():
+        cache_state = KVCacheState() if use_kv_cache else None
+
         x = torch.clip(x, -clip, clip)
         device = x.device

@@ -396,7 +478,7 @@ def auto_regressive_inference
         y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)

         x_token = tokenizer.encode(x, half=True)
-
+
         initial_seq_len = x.size(1)
         batch_size = x_token[0].size(0)
         total_seq_len = initial_seq_len + pred_len
@@ -417,24 +499,39 @@ def auto_regressive_inference
             ran = trange
         else:
             ran = range
+
         for i in ran(pred_len):
             current_seq_len = initial_seq_len + i
             window_len = min(current_seq_len, max_context)

             if current_seq_len <= max_context:
-                input_tokens = [
-                    pre_buffer[:, :window_len],
-                    post_buffer[:, :window_len]
-                ]
+                s1_window = pre_buffer[:, :window_len]
+                s2_window = post_buffer[:, :window_len]
             else:
-                input_tokens = [pre_buffer, post_buffer]
+                s1_window = pre_buffer
+                s2_window = post_buffer

             context_end = current_seq_len
             context_start = max(0, context_end - max_context)
             current_stamp = full_stamp[:, context_start:context_end, :].contiguous()

-            s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp)
-            s1_logits = s1_logits[:, -1, :]
+            if use_kv_cache:
+                if cache_state.logits is None:
+                    if window_len <= 0:
+                        raise RuntimeError("KV cache decoding requires a non-empty context window.")
+                    s1_logits_full, hidden_states, past_kv = model.decode_s1(
+                        s1_window, s2_window, current_stamp, padding_mask=None, max_context=max_context, return_kv=True
+                    )
+                    cache_state.logits = s1_logits_full[:, -1, :]
+                    cache_state.hidden_states = hidden_states
+                    cache_state.past_kv = past_kv
+
+                s1_logits = cache_state.logits
+                context = cache_state.hidden_states
+            else:
+                s1_logits, context, _ = model.decode_s1(s1_window, s2_window, current_stamp)
+                s1_logits = s1_logits[:, -1, :]
+

             sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)
             s2_logits = model.decode_s2(context, sample_pre)
@@ -453,6 +550,17 @@ def auto_regressive_inference
             pre_buffer[:, -1] = sample_pre.squeeze(-1)
             post_buffer[:, -1] = sample_post.squeeze(-1)

+            if use_kv_cache:
+                _update_kv_cache_state(
+                    model=model,
+                    cache_state=cache_state,
+                    sample_pre=sample_pre,
+                    sample_post=sample_post,
+                    full_stamp=full_stamp,
+                    current_seq_len=current_seq_len,
+                    max_context=max_context,
+                )
+
         full_pre = torch.cat([x_token[0], generated_pre], dim=1)
         full_post = torch.cat([x_token[1], generated_post], dim=1)

@@ -495,18 +603,18 @@ def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
         self.tokenizer = self.tokenizer.to(self.device)
         self.model = self.model.to(self.device)

-    def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose):
+    def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, use_kv_cache=False):
         x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
         x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device)
         y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)

         preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor,
                                           self.max_context, pred_len,
-                                          self.clip, T, top_k, top_p, sample_count, verbose)
+                                          self.clip, T, top_k, top_p, sample_count, verbose, use_kv_cache=use_kv_cache)
         preds = preds[:, -pred_len:, :]

         return preds

-    def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
+    def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, use_kv_cache=False):
         if not isinstance(df, pd.DataFrame):
             raise ValueError("Input must be a pandas DataFrame.")
@@ -540,7 +648,7 @@ def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=

         x_stamp = x_stamp[np.newaxis, :]
         y_stamp = y_stamp[np.newaxis, :]

-        preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose)
+        preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, use_kv_cache=use_kv_cache)
         preds = preds.squeeze(0)
         preds = preds * (x_std + 1e-5) + x_mean
@@ -549,7 +657,7 @@ def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=

         return pred_df

-    def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
+    def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, use_kv_cache=False):
         """
         Perform parallel (batch) prediction on multiple time series.
         All series must have the same historical length and prediction length (pred_len).
@@ -639,7 +747,7 @@ def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T
         x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32)  # (B, seq_len, time_feat)
         y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32)  # (B, pred_len, time_feat)

-        preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose)
+        preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose, use_kv_cache=use_kv_cache)
         # preds: (B, pred_len, feat)

         pred_dfs = []
diff --git a/model/module.py b/model/module.py
index f2a05158..7ec6a892 100644
--- a/model/module.py
+++ b/model/module.py
@@ -1,4 +1,5 @@
 import math
+from dataclasses import dataclass

 from einops import rearrange, reduce
 import torch
@@ -7,6 +8,43 @@
 import torch.nn.functional as F


+@dataclass
+class LayerCache:
+    k: torch.Tensor | None = None
+    v: torch.Tensor | None = None
+
+    @property
+    def is_empty(self) -> bool:
+        return self.k is None or self.v is None
+
+    @property
+    def seq_len(self) -> int:
+        if self.is_empty:
+            return 0
+        return self.k.size(2)
+
+    def get(self) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.is_empty:
+            raise ValueError("LayerCache is empty.")
+        return self.k, self.v
+
+    def append(self, k: torch.Tensor, v: torch.Tensor, max_cache_len: int | None = None) -> None:
+        if self.is_empty:
+            self.k = k.contiguous()
+            self.v = v.contiguous()
+        else:
+            self.k = torch.cat([self.k, k], dim=2).contiguous()
+            self.v = torch.cat([self.v, v], dim=2).contiguous()
+
+        if max_cache_len is not None and self.seq_len > max_cache_len:
+            self.k = self.k[:, :, -max_cache_len:, :].contiguous()
+            self.v = self.v[:, :, -max_cache_len:, :].contiguous()
+
+    def reset(self) -> None:
+        self.k = None
+        self.v = None
+
+
 class DifferentiableEntropyFunction(Function):
     @staticmethod
     def forward(ctx, zq, basis, K, eps):
@@ -300,17 +338,39 @@ def _update_cos_sin_cache(self, x, seq_len):
         self.sin_cached = emb.sin()[None, None, :, :]
         return self.cos_cached, self.sin_cached

-    def forward(self, q, k):
-        cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
-        return (
-            (q * cos) + (self._rotate_half(q) * sin),
-            (k * cos) + (self._rotate_half(k) * sin),
-        )
+    def forward(self, q, k, position_offset: int = 0):
+        """
+        This implementation uses window-relative RoPE positions, not absolute ones:
+        position_offset is the position within the current sliding window,
+        which always starts from 0.
+        This matches the baseline training behavior, where each sliding window
+        is processed independently, and allows `_forward_with_cache` to append cached
+        segments without phase discontinuities.
+        """
+        # TODO: Simplify by merging with `apply_rotary` after refactoring tests (affected by numerical instability)
+        seq_len = q.shape[-2]
+        cos, sin = self._select_cos_sin(q, position_offset, seq_len)
+        return self._apply_with_cos_sin(q, cos, sin), self._apply_with_cos_sin(k, cos, sin)

     def _rotate_half(self, x):
         x1, x2 = x.chunk(2, dim=-1)
         return torch.cat((-x2, x1), dim=-1)

+    def apply_rotary(self, x, position_offset: int = 0):
+        seq_len = x.shape[-2]
+        cos, sin = self._select_cos_sin(x, position_offset, seq_len)
+        return self._apply_with_cos_sin(x, cos, sin)
+
+    def _select_cos_sin(self, x, position_offset: int, seq_len: int):
+        total_len = position_offset + seq_len
+        cos_full, sin_full = self._update_cos_sin_cache(x, total_len)
+        cos = cos_full[:, :, position_offset:position_offset + seq_len, :]
+        sin = sin_full[:, :, position_offset:position_offset + seq_len, :]
+        return cos, sin
+
+    def _apply_with_cos_sin(self, x, cos, sin):
+        return (x * cos) + (self._rotate_half(x) * sin)
+

 class MultiHeadAttentionWithRoPE(nn.Module):
     def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
@@ -327,13 +387,29 @@ def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
         self.attn_dropout_p = attn_dropout_p
         self.resid_dropout = nn.Dropout(resid_dropout_p)

-    def forward(self, x, key_padding_mask=None):
+    def forward(self, x, key_padding_mask=None, layer_cache: LayerCache | None = None, max_cache_len=None):
         batch_size, seq_len, _ = x.shape

         q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
         k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
         v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

+        incremental = layer_cache is not None and not layer_cache.is_empty
+
+        if incremental:
+            attn_output = self._forward_with_cache(q, k, v, seq_len, layer_cache)
+        else:
+            attn_output = self._forward_no_cache(q, k, v, key_padding_mask, seq_len)
+
+        if layer_cache is not None:
+            layer_cache.append(k, v, max_cache_len=max_cache_len)
+
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
+        out = self.resid_dropout(self.out_proj(attn_output))
+
+        return out
+
+    def _forward_no_cache(self, q, k, v, key_padding_mask, seq_len):
         q, k = self.rotary(q, k)

         if key_padding_mask is not None:
@@ -342,15 +418,33 @@ def forward(self, x, key_padding_mask=None):
         else:
             attn_mask = None

-        attn_output = F.scaled_dot_product_attention(
+        return F.scaled_dot_product_attention(
             q, k, v,
             attn_mask=attn_mask,
             dropout_p=self.attn_dropout_p if self.training else 0.0,
-            is_causal=True
+            is_causal=True,
         )

-        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
-        return self.resid_dropout(self.out_proj(attn_output))
+    def _forward_with_cache(self, q, k, v, seq_len, layer_cache: LayerCache):
+        cache_k, cache_v = layer_cache.get()
+        cache_len = cache_k.size(2)
+        if seq_len != 1:
+            raise ValueError("Layer cache expects seq_len == 1 during incremental decoding.")
+
+        q_rot, k_new_rot = self.rotary.forward(q, k, position_offset=cache_len)
+        cached_k_rot = self.rotary.apply_rotary(cache_k)
+
+        k_total = torch.cat([cached_k_rot, k_new_rot], dim=2)
+        v_total = torch.cat([cache_v, v], dim=2)
+
+        return F.scaled_dot_product_attention(
+            q_rot,
+            k_total,
+            v_total,
+            attn_mask=None,
+            dropout_p=self.attn_dropout_p if self.training else 0.0,
+            is_causal=False,
+        )


 class MultiHeadCrossAttentionWithRoPE(nn.Module):
@@ -470,10 +564,15 @@ def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropou
         self.norm2 = RMSNorm(d_model)
         self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)

-    def forward(self, x, key_padding_mask=None):
+    def forward(self, x, key_padding_mask=None, layer_cache: LayerCache | None = None, max_cache_len=None):
         residual = x
         x = self.norm1(x)
-        attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
+        attn_out = self.self_attn(
+            x,
+            key_padding_mask=key_padding_mask,
+            layer_cache=layer_cache,
+            max_cache_len=max_cache_len,
+        )
         x = residual + attn_out

         residual = x
diff --git a/tests/test_kronos_regression.py b/tests/test_kronos_regression.py
index 7fccbffd..06dd1909 100644
--- a/tests/test_kronos_regression.py
+++ b/tests/test_kronos_regression.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 from tqdm import tqdm
+import itertools

 from model import Kronos, KronosPredictor, KronosTokenizer

@@ -42,8 +43,8 @@ def set_seed(seed: int) -> None:
         torch.backends.cudnn.benchmark = False


-@pytest.mark.parametrize("context_len", TEST_CTX_LEN)
-def test_kronos_predictor_regression(context_len):
+@pytest.mark.parametrize("context_len, use_kv_cache", itertools.product(TEST_CTX_LEN, [False, True]))
+def test_kronos_predictor_regression(context_len, use_kv_cache):
     set_seed(SEED)

     expected_output_path = OUTPUT_DATA_DIR / f"regression_output_{context_len}.csv"
@@ -77,18 +78,20 @@ def test_kronos_predictor_regression(context_len):
         top_p=1.0,
         verbose=False,
         sample_count=1,
+        use_kv_cache=use_kv_cache,
     )

     obtained = pred_df[FEATURE_NAMES].to_numpy(dtype=np.float32)

     abs_diff = np.abs(obtained - expected)
     rel_diff = abs_diff / (np.abs(expected) + 1e-9)
-    print(f"Abs diff: {np.max(abs_diff)}, Rel diff: {np.max(rel_diff)}")
+    print(f"Abs diff: {np.max(abs_diff)}, Rel diff: {np.max(rel_diff):.6e}")

     np.testing.assert_allclose(obtained, expected, rtol=REL_TOLERANCE)


 @pytest.mark.parametrize("context_len, expected_mse", zip(MSE_CTX_LEN, MSE_EXPECTED))
-def test_kronos_predictor_mse(context_len, expected_mse):
+@pytest.mark.parametrize("use_kv_cache", [False, True])
+def test_kronos_predictor_mse(context_len, expected_mse, use_kv_cache):
     set_seed(SEED)
     df = pd.read_csv(INPUT_DATA_PATH, parse_dates=["timestamps"])
@@ -111,7 +114,7 @@ def test_kronos_predictor_mse(context_len, expected_mse):
     mse_values = []
     sample_indices = sampled_rows.index.to_list()

     with torch.no_grad():
-        for row_idx in tqdm(sample_indices):
+        for row_idx in tqdm(sample_indices, disable=True):
             context_slice = df.iloc[row_idx - context_len : row_idx].copy()
             future_slice = df.iloc[row_idx : row_idx + MSE_PRED_LEN].copy()
@@ -125,6 +128,7 @@ def test_kronos_predictor_mse(context_len, expected_mse):
                 top_p=1.0,
                 verbose=False,
                 sample_count=1,
+                use_kv_cache=use_kv_cache,
             )

             obtained = pred_df[MSE_FEATURE_NAMES].to_numpy(dtype=np.float32)
@@ -135,6 +139,6 @@ def test_kronos_predictor_mse(context_len, expected_mse):

     mse = np.mean(mse_values).item()
     mse_diff = mse - expected_mse
-    print(f"Average MSE: {mse} (Diff vs expected: {mse_diff:+})")
+    print(f"Average MSE: {mse} (Diff vs expected: {mse_diff:+.6e})")

     assert abs(mse_diff) <= MSE_TOLERANCE, f"MSE {mse} differs from expected {expected_mse}"
diff --git a/tests/test_kv_cache_equivalence.py b/tests/test_kv_cache_equivalence.py
new file mode 100644
index 00000000..acc33511
--- /dev/null
+++ b/tests/test_kv_cache_equivalence.py
@@ -0,0 +1,88 @@
+import random
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+
+from model import Kronos, KronosPredictor, KronosTokenizer
+
+
+TEST_DATA_ROOT = Path(__file__).parent / "data"
+INPUT_DATA_PATH = TEST_DATA_ROOT / "regression_input.csv"
+
+# Reuse same revisions as regression tests
+MODEL_REVISION = "901c26c1332695a2a8f243eb2f37243a37bea320"
+TOKENIZER_REVISION = "0e0117387f39004a9016484a186a908917e22426"
+DEVICE = "cpu"
+SEED = 123
+MAX_CTX_LEN = 512
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.backends.cudnn.is_available():
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+@pytest.mark.parametrize("context_len,pred_len", [(256, 8), (512, 8)])
+def test_kv_cache_equivalence(context_len, pred_len):
+    set_seed(SEED)
+
+    df = pd.read_csv(INPUT_DATA_PATH, parse_dates=["timestamps"])
+
+    if df.shape[0] < context_len + pred_len:
+        raise ValueError("Example data does not contain enough rows for the equivalence test.")
+
+    context_df = df.iloc[:context_len].copy()
+    x_timestamp = context_df["timestamps"].reset_index(drop=True)
+    future_timestamp = df["timestamps"].iloc[context_len:context_len + pred_len].reset_index(drop=True)
+
+    features = ["open", "high", "low", "close", "volume", "amount"]
+    context_features = context_df[features].reset_index(drop=True)
+
+    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base", revision=TOKENIZER_REVISION)
+    model = Kronos.from_pretrained("NeoQuasar/Kronos-small", revision=MODEL_REVISION)
+    tokenizer.eval()
+    model.eval()
+
+    predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=MAX_CTX_LEN)
+
+    with torch.no_grad():
+        pred_df_no_cache = predictor.predict(
+            df=context_features,
+            x_timestamp=x_timestamp,
+            y_timestamp=future_timestamp,
+            pred_len=pred_len,
+            T=1.0,
+            top_k=1,
+            top_p=1.0,
+            verbose=False,
+            sample_count=1,
+            use_kv_cache=False,
+        )
+
+        pred_df_cache = predictor.predict(
+            df=context_features,
+            x_timestamp=x_timestamp,
+            y_timestamp=future_timestamp,
+            pred_len=pred_len,
+            T=1.0,
+            top_k=1,
+            top_p=1.0,
+            verbose=False,
+            sample_count=1,
+            use_kv_cache=True,
+        )
+
+    np.testing.assert_allclose(
+        pred_df_no_cache[features].to_numpy(dtype=np.float32),
+        pred_df_cache[features].to_numpy(dtype=np.float32),
+        rtol=0,
+        atol=0,
+    )
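Usage sketch for the new flag: `use_kv_cache` defaults to False on `generate`, `predict`, and `predict_batch`, so existing callers are unaffected. The snippet below shows how the cached path would be exercised, assuming the same published checkpoints as the tests above; `df`, `x_timestamp`, and `y_timestamp` are placeholders prepared exactly as the existing `predict()` API expects.

    from model import Kronos, KronosPredictor, KronosTokenizer

    # Load the checkpoints referenced by the tests (revisions omitted here for brevity).
    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
    predictor = KronosPredictor(model, tokenizer, device="cpu", max_context=512)

    # df / x_timestamp / y_timestamp prepared as for the existing predict() call.
    pred_df = predictor.predict(
        df, x_timestamp, y_timestamp, pred_len=8,
        T=1.0, top_k=1, top_p=1.0, sample_count=1,
        verbose=False, use_kv_cache=True,
    )

Once the generated sequence grows past `max_context`, `_update_kv_cache_state` warns once and then resets the cache every step, so the speedup only applies while `initial_seq_len + pred_len` stays within `max_context`.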