Skip to content

Commit 3a7c498

Browse files
deepanshululladeepanshu
andauthored
Add GitlabPromptCache and enable subfolder access (#15712)
* Add GitlabPromptCache and enable subfolder access * Add GitlabPromptCache and enable subfolder access * Add GitlabPromptCache and enable subfolder access --------- Co-authored-by: deepanshu <[email protected]>
1 parent 647f2f5 commit 3a7c498

File tree

4 files changed

+554
-28
lines changed

4 files changed

+554
-28
lines changed

litellm/integrations/gitlab/__init__.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,21 @@
88
from litellm.types.prompts.init_prompts import SupportedPromptIntegrations
99
from litellm.integrations.custom_prompt_management import CustomPromptManagement
1010
from litellm.types.prompts.init_prompts import PromptSpec, PromptLiteLLMParams
11-
from .gitlab_prompt_manager import GitLabPromptManager
11+
from .gitlab_prompt_manager import GitLabPromptManager, GitLabPromptCache
1212

1313
# Global instances
1414
global_gitlab_config: Optional[dict] = None
1515

1616

1717
def set_global_gitlab_config(config: dict) -> None:
1818
"""
19-
Set the global BitBucket configuration for prompt management.
19+
Set the global gitlab configuration for prompt management.
2020
2121
Args:
22-
config: Dictionary containing BitBucket configuration
23-
- workspace: BitBucket workspace name
22+
config: Dictionary containing gitlab configuration
23+
- workspace: gitlab workspace name
2424
- repository: Repository name
25-
- access_token: BitBucket access token
25+
- access_token: gitlab access token
2626
- branch: Branch to fetch prompts from (default: main)
2727
"""
2828
import litellm
@@ -34,24 +34,24 @@ def prompt_initializer(
3434
litellm_params: "PromptLiteLLMParams", prompt_spec: "PromptSpec"
3535
) -> "CustomPromptManagement":
3636
"""
37-
Initialize a prompt from a BitBucket repository.
37+
Initialize a prompt from a Gitlab repository.
3838
"""
3939
gitlab_config = getattr(litellm_params, "gitlab_config", None)
4040
prompt_id = getattr(litellm_params, "prompt_id", None)
4141

4242

4343
if not gitlab_config:
4444
raise ValueError(
45-
"bitbucket_config is required for BitBucket prompt integration"
45+
"gitlab_config is required for gitlab prompt integration"
4646
)
4747

4848
try:
49-
bitbucket_prompt_manager = GitLabPromptManager(
49+
gitlab_prompt_manager = GitLabPromptManager(
5050
gitlab_config=gitlab_config,
5151
prompt_id=prompt_id,
5252
)
5353

54-
return bitbucket_prompt_manager
54+
return gitlab_prompt_manager
5555
except Exception as e:
5656
raise e
5757

@@ -90,6 +90,7 @@ def _gitlab_prompt_initializer(
9090
# Export public API
9191
__all__ = [
9292
"GitLabPromptManager",
93+
"GitLabPromptCache",
9394
"set_global_gitlab_config",
9495
"global_gitlab_config",
9596
]

litellm/integrations/gitlab/gitlab_prompt_manager.py

Lines changed: 170 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,24 @@
1212
)
1313
from litellm.types.llms.openai import AllMessageValues
1414
from litellm.types.utils import StandardCallbackDynamicParams
15-
1615
from litellm.integrations.gitlab.gitlab_client import GitLabClient
1716

1817

18+
GITLAB_PREFIX = "gitlab::"
19+
20+
def encode_prompt_id(raw_id: str) -> str:
21+
"""Convert GitLab path IDs like 'invoice/extract' → 'gitlab::invoice::extract'"""
22+
if raw_id.startswith(GITLAB_PREFIX):
23+
return raw_id # already encoded
24+
return f"{GITLAB_PREFIX}{raw_id.replace('/', '::')}"
25+
26+
def decode_prompt_id(encoded_id: str) -> str:
27+
"""Convert 'gitlab::invoice::extract' → 'invoice/extract'"""
28+
if not encoded_id.startswith(GITLAB_PREFIX):
29+
return encoded_id
30+
return encoded_id[len(GITLAB_PREFIX):].replace("::", "/")
31+
32+
1933
class GitLabPromptTemplate:
2034
def __init__(
2135
self,
@@ -87,6 +101,7 @@ def __init__(
87101

88102
def _id_to_repo_path(self, prompt_id: str) -> str:
89103
"""Map a prompt_id to a repo path (respects prompts_path and adds .prompt)."""
104+
prompt_id = decode_prompt_id(prompt_id)
90105
if self.prompts_path:
91106
return f"{self.prompts_path}/{prompt_id}.prompt"
92107
return f"{prompt_id}.prompt"
@@ -101,26 +116,27 @@ def _repo_path_to_id(self, repo_path: str) -> str:
101116
path = path[len(self.prompts_path.strip("/")) + 1 :]
102117
if path.endswith(".prompt"):
103118
path = path[: -len(".prompt")]
104-
return path
119+
return encode_prompt_id(path)
105120

106121
# ---------- loading ----------
107122

108123
def _load_prompt_from_gitlab(self, prompt_id: str, *, ref: Optional[str] = None) -> None:
109124
"""Load a specific .prompt file from GitLab (scoped under prompts_path if set)."""
110125
try:
126+
# prompt_id = decode_prompt_id(prompt_id)
111127
file_path = self._id_to_repo_path(prompt_id)
112128
prompt_content = self.gitlab_client.get_file_content(file_path, ref=ref)
113129
if prompt_content:
114130
template = self._parse_prompt_file(prompt_content, prompt_id)
115131
self.prompts[prompt_id] = template
116132
except Exception as e:
117-
raise Exception(f"Failed to load prompt '{prompt_id}' from GitLab: {e}")
133+
raise Exception(f"Failed to load prompt '{encode_prompt_id(prompt_id)}' from GitLab: {e}")
118134

119135
def load_all_prompts(self, *, recursive: bool = True) -> List[str]:
120136
"""
121137
Eagerly load all .prompt files from prompts_path. Returns loaded IDs.
122138
"""
123-
files = self.list_templates(recursive=recursive) # reuse logic
139+
files = self.list_templates(recursive=recursive)
124140
loaded: List[str] = []
125141
for pid in files:
126142
if pid not in self.prompts:
@@ -195,9 +211,6 @@ def get_template(self, template_id: str) -> Optional[GitLabPromptTemplate]:
195211
return self.prompts.get(template_id)
196212

197213
def list_templates(self, *, recursive: bool = True) -> List[str]:
198-
"""
199-
List available prompt IDs discovered under prompts_path (no extension, relative to prompts_path).
200-
"""
201214
"""
202215
List available prompt IDs under prompts_path (no extension).
203216
Compatible with both list_files signatures:
@@ -248,7 +261,7 @@ class GitLabPromptManager(CustomPromptManagement):
248261
"access_token": "glpat_***",
249262
"tag": "v1.2.3", # optional; takes precedence
250263
"branch": "main", # default fallback
251-
"prompts_path": "prompts/chat" # <--- NEW
264+
"prompts_path": "prompts/chat"
252265
}
253266
"""
254267

@@ -438,9 +451,11 @@ def _compile_prompt_helper(
438451
prompt_version: Optional[int] = None,
439452
) -> PromptManagementClient:
440453
try:
441-
if prompt_id not in self.prompt_manager.prompts:
454+
decoded_id = decode_prompt_id(prompt_id)
455+
if decoded_id not in self.prompt_manager.prompts:
442456
git_ref = getattr(dynamic_callback_params, "extra", {}).get("git_ref") if hasattr(dynamic_callback_params, "extra") else None
443-
self.prompt_manager._load_prompt_from_gitlab(prompt_id, ref=git_ref)
457+
self.prompt_manager._load_prompt_from_gitlab(decoded_id, ref=git_ref)
458+
444459

445460
rendered_prompt, prompt_metadata = self.get_prompt_template(
446461
prompt_id, prompt_variables
@@ -486,3 +501,148 @@ def get_chat_completion_prompt(
486501
prompt_label,
487502
prompt_version,
488503
)
504+
505+
506+
class GitLabPromptCache:
507+
"""
508+
Cache all .prompt files from a GitLab repo into memory.
509+
510+
- Keys are the *repo file paths* (e.g. "prompts/chat/greet/hi.prompt")
511+
mapped to JSON-like dicts containing content + metadata.
512+
- Also exposes a by-ID view (ID == path relative to prompts_path without ".prompt",
513+
e.g. "greet/hi").
514+
515+
Usage:
516+
517+
cfg = {
518+
"project": "group/subgroup/repo",
519+
"access_token": "glpat_***",
520+
"prompts_path": "prompts/chat", # optional, can be empty for repo root
521+
# "branch": "main", # default is "main"
522+
# "tag": "v1.2.3", # takes precedence over branch
523+
# "base_url": "https://gitlab.com/api/v4" # default
524+
}
525+
526+
cache = GitLabPromptCache(cfg)
527+
cache.load_all() # fetch + parse all .prompt files
528+
529+
print(cache.list_files()) # repo file paths
530+
print(cache.list_ids()) # template IDs relative to prompts_path
531+
532+
prompt_json = cache.get_by_file("prompts/chat/greet/hi.prompt")
533+
prompt_json2 = cache.get_by_id("greet/hi")
534+
535+
# If GitLab content changes and you want to refresh:
536+
cache.reload() # re-scan and refresh all
537+
"""
538+
539+
def __init__(
540+
self,
541+
gitlab_config: Dict[str, Any],
542+
*,
543+
ref: Optional[str] = None,
544+
gitlab_client: Optional[GitLabClient] = None,
545+
) -> None:
546+
# Build a PromptManager (which internally builds TemplateManager + Client)
547+
self.prompt_manager = GitLabPromptManager(
548+
gitlab_config=gitlab_config,
549+
prompt_id=None,
550+
ref=ref,
551+
gitlab_client=gitlab_client,
552+
)
553+
self.template_manager: GitLabTemplateManager = self.prompt_manager.prompt_manager
554+
555+
# In-memory stores
556+
self._by_file: Dict[str, Dict[str, Any]] = {}
557+
self._by_id: Dict[str, Dict[str, Any]] = {}
558+
559+
# -------------------------
560+
# Public API
561+
# -------------------------
562+
563+
def load_all(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
564+
"""
565+
Scan GitLab for all .prompt files under prompts_path, load and parse each,
566+
and return the mapping of repo file path -> JSON-like dict.
567+
"""
568+
ids = self.template_manager.list_templates(recursive=recursive) # IDs relative to prompts_path
569+
for pid in ids:
570+
# Ensure template is loaded into TemplateManager
571+
if pid not in self.template_manager.prompts:
572+
self.template_manager._load_prompt_from_gitlab(pid)
573+
574+
tmpl = self.template_manager.get_template(pid)
575+
if tmpl is None:
576+
# If something raced/failed, try once more
577+
self.template_manager._load_prompt_from_gitlab(pid)
578+
tmpl = self.template_manager.get_template(pid)
579+
if tmpl is None:
580+
continue
581+
582+
file_path = self.template_manager._id_to_repo_path(pid) # "prompts/chat/..../file.prompt"
583+
entry = self._template_to_json(pid, tmpl)
584+
585+
self._by_file[file_path] = entry
586+
# prefixed_id = pid if pid.startswith("gitlab::") else f"gitlab::{pid}"
587+
encoded_id = encode_prompt_id(pid)
588+
self._by_id[encoded_id] = entry
589+
# self._by_id[pid] = entry
590+
591+
return self._by_id
592+
593+
def reload(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
594+
"""Clear the cache and re-load from GitLab."""
595+
self._by_file.clear()
596+
self._by_id.clear()
597+
return self.load_all(recursive=recursive)
598+
599+
def list_files(self) -> List[str]:
600+
"""Return the repo file paths currently cached."""
601+
return list(self._by_file.keys())
602+
603+
def list_ids(self) -> List[str]:
604+
"""Return the template IDs (relative to prompts_path, without extension) currently cached."""
605+
return list(self._by_id.keys())
606+
607+
def get_by_file(self, file_path: str) -> Optional[Dict[str, Any]]:
608+
"""Get a cached prompt JSON by repo file path."""
609+
return self._by_file.get(file_path)
610+
611+
def get_by_id(self, prompt_id: str) -> Optional[Dict[str, Any]]:
612+
"""Get a cached prompt JSON by prompt ID (relative to prompts_path)."""
613+
if prompt_id in self._by_id:
614+
return self._by_id[prompt_id]
615+
616+
# Try normalized forms
617+
decoded = decode_prompt_id(prompt_id)
618+
encoded = encode_prompt_id(decoded)
619+
620+
return self._by_id.get(encoded) or self._by_id.get(decoded)
621+
622+
# -------------------------
623+
# Internals
624+
# -------------------------
625+
626+
def _template_to_json(self, prompt_id: str, tmpl: GitLabPromptTemplate) -> Dict[str, Any]:
627+
"""
628+
Normalize a GitLabPromptTemplate into a JSON-like dict that is easy to serialize.
629+
"""
630+
# Safer copy of metadata (avoid accidental mutation)
631+
md = dict(tmpl.metadata or {})
632+
633+
# Pull standard fields (also present in metadata sometimes)
634+
model = tmpl.model
635+
temperature = tmpl.temperature
636+
max_tokens = tmpl.max_tokens
637+
optional_params = dict(tmpl.optional_params or {})
638+
639+
return {
640+
"id": prompt_id, # e.g. "greet/hi"
641+
"path": self.template_manager._id_to_repo_path(prompt_id), # e.g. "prompts/chat/greet/hi.prompt"
642+
"content": tmpl.content, # rendered content (without frontmatter)
643+
"metadata": md, # parsed frontmatter
644+
"model": model,
645+
"temperature": temperature,
646+
"max_tokens": max_tokens,
647+
"optional_params": optional_params,
648+
}

tests/test_litellm/integrations/gitlab/test_gitlab_integration.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import os
22
import sys
3-
from unittest.mock import MagicMock, patch
3+
44

55
import pytest
66

7-
sys.path.insert(0, os.path.abspath("../../.."))
7+
sys.path.insert(
8+
0, os.path.abspath("../../..")
9+
) # Adds the parent directory to the system path
810

11+
from unittest.mock import MagicMock, patch
912
from litellm.integrations.gitlab.gitlab_prompt_manager import GitLabPromptManager
1013

1114

@@ -84,8 +87,9 @@ def test_gitlab_prompt_manager_error_handling_load(mock_client_class):
8487

8588
config = {"project": "g/s/r", "access_token": "tkn"}
8689

87-
with pytest.raises(Exception, match="Failed to load prompt 'oops' from GitLab"):
88-
GitLabPromptManager(config, prompt_id="oops").prompt_manager # triggers load
90+
with pytest.raises(Exception, match="Failed to load prompt 'gitlab::oops' from GitLab"):
91+
GitLabPromptManager(config, prompt_id="oops").prompt_manager
92+
8993

9094

9195
def test_gitlab_prompt_manager_config_validation_via_client_ctor():
@@ -257,7 +261,7 @@ def test_gitlab_prompt_manager_list_templates_with_prompts_path(mock_client_clas
257261
# list_templates strips folder prefix + extension
258262
ids = manager.get_available_prompts()
259263
assert "a" in ids
260-
assert "sub/b" in ids
264+
assert "gitlab::sub::b" in ids
261265
assert all(not x.endswith(".prompt") for x in ids)
262266
assert all("/prompts/chat/" not in x for x in ids)
263267

@@ -284,8 +288,8 @@ def test_gitlab_template_manager_load_all_prompts(mock_client_class):
284288

285289
pm = GitLabPromptManager(config).prompt_manager
286290
loaded = pm.load_all_prompts()
287-
assert set(loaded) == {"a", "sub/b"}
288-
assert "a" in pm.prompts and "sub/b" in pm.prompts
291+
assert set(loaded) == {"gitlab::a", "gitlab::sub::b"}
292+
assert "gitlab::a" in pm.prompts and "gitlab::sub::b" in pm.prompts
289293

290294

291295
# -----------------------------
@@ -452,4 +456,4 @@ def test_gitlab_prompt_version_with_prompts_path(mock_client_class):
452456
# Path should include prompts_path and end with .prompt
453457
mock_client.get_file_content.assert_any_call(
454458
"prompts/chat/folder/sub/my_prompt.prompt", ref="commit-sha-999"
455-
)
459+
)

0 commit comments

Comments
 (0)