diff --git a/README.md b/README.md index 5bc7683..e14ef1e 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,30 @@ Here is an example of how to do this using SSE, which is a deprecated http-based export $(grep -v '^#' ~/.rh-jira-mcp.env | xargs) && python server.py --transport sse --port 3075 ``` +For either Streamable HTTP or SSE, the JIRA_API_KEY in your environment is ignored (and not neeed). This is an important security feature, because otherwise anyone who +had access to the HTTP service would have access to the account information for whoever configured and ran that server. +Instead, calling applications must send in their own Jira token as a Bearer token. Here is an example of how to do that using Llama Stack: + +```python +from llama_stack_client import LlamaStackClient + +client = LlamaStackClient(base_url=LLAMA_STACK_URL) +mcp_llama_stack_client_response = client.responses.create( + model=LLAMA_STACK_MODEL_ID, + input="Tell me about RHAISTRAT-24.", + tools=[ + { + "type": "mcp", + "server_url": JIRA_MCP_URL, + "server_label": "Jira_tools", + "headers": { + "Authorization": f"Bearer {JIRA_API_TOKEN}" + } + } + ] +) +``` + ## Available Tools This MCP server provides the following tools: diff --git a/server.py b/server.py index 896843d..7c1dcaa 100755 --- a/server.py +++ b/server.py @@ -5,9 +5,9 @@ from dotenv import load_dotenv from jira import JIRA from fastmcp import FastMCP +from fastmcp.server.dependencies import get_http_headers from fastapi import HTTPException import json -import logging ## Custom fields IDs QA_CONTACT_FID = "customfield_12315948" @@ -20,12 +20,36 @@ JIRA_ENABLE_WRITE_OPERATIONS_STRING = os.getenv("JIRA_ENABLE_WRITE", "false") ENABLE_WRITE = JIRA_ENABLE_WRITE_OPERATIONS_STRING.lower() == "true" -if not all([JIRA_URL, JIRA_API_TOKEN]): - raise RuntimeError("Missing JIRA_URL or JIRA_API_TOKEN environment variables") +jira_client = JIRA(server=JIRA_URL, token_auth=JIRA_API_TOKEN) # ─── 2. Create a Jira client ─────────────────────────────────────────────────── # Uses token_auth (API token) for authentication. -jira_client = JIRA(server=JIRA_URL, token_auth=JIRA_API_TOKEN) + + +def get_jira_client(headers: dict[str, str]): + """ + Get a JIRA client instance. + + If a global jira_client exists (stdio mode), use it. + Otherwise, create a new client using the authorization header (server mode). + """ + global jira_client + + # If we have a global client (stdio mode with env token), use it + if jira_client is not None: + return jira_client + + # Server mode: extract token from authorization header + auth_header = headers.get("authorization", headers.get("Authorization")) + if auth_header: + parts = auth_header.split(" ") + if len(parts) != 2: + raise RuntimeError("Invalid Authorization header format. Expected: 'Bearer '") + token = parts[1] + return JIRA(server=JIRA_URL, token_auth=token) + + raise RuntimeError("No access token available. Provide Authorization header with Bearer token.") + # ─── 3. Instantiate the MCP server ───────────────────────────────────────────── mcp = FastMCP("Jira Context Server") @@ -35,11 +59,11 @@ @mcp.tool() def get_jira(issue_key: str) -> str: """ - Fetch the Jira issue identified by 'issue_key' using jira_client, - then return a Markdown string: "# ISSUE-KEY: summary\n\ndescription" + Fetch the Jira issue identified by 'issue_key' then + return a Markdown string: "# ISSUE-KEY: summary\n\ndescription" """ try: - issue = jira_client.issue(issue_key) + issue = get_jira_client(get_http_headers()).issue(issue_key) except Exception as e: # If the JIRA client raises an error (e.g. issue not found), # wrap it in an HTTPException so MCP/Client sees a 4xx/5xx. @@ -67,7 +91,7 @@ def to_markdown(obj): def search_issues(jql: str, max_results: int = 100) -> str: """Search issues using JQL.""" try: - issues = jira_client.search_issues(jql, maxResults=max_results) + issues = get_jira_client(get_http_headers()).search_issues(jql, maxResults=max_results) # Extract only essential fields to avoid token limit issues simplified_issues = [] for issue in issues: @@ -101,7 +125,7 @@ def search_issues(jql: str, max_results: int = 100) -> str: def search_users(query: str, max_results: int = 10) -> str: """Search users by query.""" try: - users = jira_client.search_users(query, maxResults=max_results) + users = get_jira_client(get_http_headers()).search_users(query, maxResults=max_results) return to_markdown([u.raw for u in users]) except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to search users: {e}") @@ -111,7 +135,7 @@ def search_users(query: str, max_results: int = 10) -> str: def list_projects() -> str: """List all projects.""" try: - projects = jira_client.projects() + projects = get_jira_client(get_http_headers()).projects() return to_markdown([p.raw for p in projects]) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to fetch projects: {e}") @@ -121,7 +145,7 @@ def list_projects() -> str: def get_project(project_key: str) -> str: """Get a project by key.""" try: - project = jira_client.project(project_key) + project = get_jira_client(get_http_headers()).project(project_key) return to_markdown(project) except Exception as e: raise HTTPException(status_code=404, detail=f"Failed to fetch project: {e}") @@ -131,7 +155,7 @@ def get_project(project_key: str) -> str: def get_project_components(project_key: str) -> str: """Get components for a project.""" try: - components = jira_client.project_components(project_key) + components = get_jira_client(get_http_headers()).project_components(project_key) return to_markdown([c.raw for c in components]) except Exception as e: raise HTTPException(status_code=404, detail=f"Failed to fetch components: {e}") @@ -141,7 +165,7 @@ def get_project_components(project_key: str) -> str: def get_project_versions(project_key: str) -> str: """Get versions for a project.""" try: - versions = jira_client.project_versions(project_key) + versions = get_jira_client(get_http_headers()).project_versions(project_key) return to_markdown([v.raw for v in versions]) except Exception as e: raise HTTPException(status_code=404, detail=f"Failed to fetch versions: {e}") @@ -151,7 +175,7 @@ def get_project_versions(project_key: str) -> str: def get_project_roles(project_key: str) -> str: """Get roles for a project.""" try: - roles = jira_client.project_roles(project_key) + roles = get_jira_client(get_http_headers()).project_roles(project_key) return to_markdown(roles) except Exception as e: raise HTTPException(status_code=404, detail=f"Failed to fetch roles: {e}") @@ -161,7 +185,7 @@ def get_project_roles(project_key: str) -> str: def get_project_permission_scheme(project_key: str) -> str: """Get permission scheme for a project.""" try: - scheme = jira_client.project_permissionscheme(project_key) + scheme = get_jira_client(get_http_headers()).project_permissionscheme(project_key) return to_markdown(scheme.raw) except Exception as e: raise HTTPException(status_code=404, detail=f"Failed to fetch permission scheme: {e}") @@ -171,7 +195,7 @@ def get_project_permission_scheme(project_key: str) -> str: def get_project_issue_types(project_key: str) -> str: """Get issue types for a project.""" try: - types = jira_client.project_issue_types(project_key) + types = get_jira_client(get_http_headers()).project_issue_types(project_key) return to_markdown([t.raw for t in types]) except Exception as e: raise HTTPException(status_code=404, detail=f"Failed to fetch issue types: {e}") @@ -181,7 +205,7 @@ def get_project_issue_types(project_key: str) -> str: def get_current_user() -> str: """Get current user info.""" try: - user = jira_client.myself() + user = get_jira_client(get_http_headers()).myself() return to_markdown(user) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to fetch current user: {e}") @@ -191,7 +215,7 @@ def get_current_user() -> str: def get_user(account_id: str) -> str: """Get user by account ID.""" try: - user = jira_client.user(account_id) + user = get_jira_client(get_http_headers()).user(account_id) return to_markdown(user.raw) except Exception as e: raise HTTPException(status_code=404, detail=f"Failed to fetch user: {e}") @@ -203,7 +227,7 @@ def get_assignable_users_for_project( ) -> str: """Get assignable users for a project.""" try: - users = jira_client.search_assignable_users_for_projects( + users = get_jira_client(get_http_headers()).search_assignable_users_for_projects( query, project_key, maxResults=max_results ) return to_markdown([u.raw for u in users]) @@ -215,7 +239,7 @@ def get_assignable_users_for_project( def get_assignable_users_for_issue(issue_key: str, query: str = "", max_results: int = 10) -> str: """Get assignable users for an issue.""" try: - users = jira_client.search_assignable_users_for_issues( + users = get_jira_client(get_http_headers()).search_assignable_users_for_issues( query, issueKey=issue_key, maxResults=max_results ) return to_markdown([u.raw for u in users]) @@ -227,7 +251,9 @@ def get_assignable_users_for_issue(issue_key: str, query: str = "", max_results: def list_boards(max_results: int = 10, project_key_or_id: str = None) -> str: """List boards, optionally filtered by project.""" try: - boards = jira_client.boards(maxResults=max_results, projectKeyOrID=project_key_or_id) + boards = get_jira_client(get_http_headers()).boards( + maxResults=max_results, projectKeyOrID=project_key_or_id + ) return to_markdown([b.raw for b in boards]) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to fetch boards: {e}") @@ -237,7 +263,7 @@ def list_boards(max_results: int = 10, project_key_or_id: str = None) -> str: def list_sprints(board_id: int, max_results: int = 10) -> str: """List sprints for a board.""" try: - sprints = jira_client.sprints(board_id, maxResults=max_results) + sprints = get_jira_client(get_http_headers()).sprints(board_id, maxResults=max_results) return to_markdown([s.raw for s in sprints]) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to fetch sprints: {e}") @@ -247,7 +273,7 @@ def list_sprints(board_id: int, max_results: int = 10) -> str: def get_sprint(sprint_id: int) -> str: """Get sprint by ID.""" try: - sprint = jira_client.sprint(sprint_id) + sprint = get_jira_client(get_http_headers()).sprint(sprint_id) return to_markdown(sprint.raw) except Exception as e: raise HTTPException(status_code=404, detail=f"Failed to fetch sprint: {e}") @@ -257,7 +283,7 @@ def get_sprint(sprint_id: int) -> str: def get_sprints_by_name(board_id: int, state: str = None) -> str: """Get sprints by name for a board, optionally filtered by state.""" try: - sprints = jira_client.sprints_by_name(board_id, state=state) + sprints = get_jira_client(get_http_headers()).sprints_by_name(board_id, state=state) return to_markdown(sprints) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to fetch sprints by name: {e}") @@ -288,7 +314,7 @@ def create_issue( if assignee: issue_dict["assignee"] = {"name": assignee} - new_issue = jira_client.create_issue(fields=issue_dict) + new_issue = get_jira_client(get_http_headers()).create_issue(fields=issue_dict) return f"Created issue {new_issue.key}: {summary}" except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to create issue: {e}") @@ -304,7 +330,7 @@ def update_issue( ) -> str: """Update an existing Jira issue.""" try: - issue = jira_client.issue(issue_key) + issue = get_jira_client(get_http_headers()).issue(issue_key) update_dict = {} if summary: @@ -329,8 +355,8 @@ def update_issue( def add_comment(issue_key: str, comment_body: str) -> str: """Add a comment to a Jira issue.""" try: - issue = jira_client.issue(issue_key) - comment = jira_client.add_comment(issue, comment_body) + issue = get_jira_client(get_http_headers()).issue(issue_key) + comment = get_jira_client(get_http_headers()).add_comment(issue, comment_body) return f"Added comment to {issue_key}: {comment.id}" except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to add comment to {issue_key}: {e}") @@ -340,7 +366,7 @@ def add_comment(issue_key: str, comment_body: str) -> str: def delete_comment(issue_key: str, comment_id: str) -> str: """Delete a comment from a Jira issue.""" try: - comment = jira_client.comment(issue_key, comment_id) + comment = get_jira_client(get_http_headers()).comment(issue_key, comment_id) comment.delete() return f"Deleted comment {comment_id} from {issue_key}" except Exception as e: @@ -354,7 +380,7 @@ def delete_comment(issue_key: str, comment_id: str) -> str: def get_issue_comments(issue_key: str) -> str: """Get all comments for a Jira issue.""" try: - issue = jira_client.issue(issue_key) + issue = get_jira_client(get_http_headers()).issue(issue_key) comments = [] for comment in issue.fields.comment.comments: comment_data = { @@ -374,8 +400,8 @@ def get_issue_comments(issue_key: str) -> str: def assign_issue(issue_key: str, assignee: str) -> str: """Assign a Jira issue to a user.""" try: - issue = jira_client.issue(issue_key) - jira_client.assign_issue(issue, assignee) + issue = get_jira_client(get_http_headers()).issue(issue_key) + get_jira_client(get_http_headers()).assign_issue(issue, assignee) return f"Assigned issue {issue_key} to {assignee}" except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to assign issue {issue_key}: {e}") @@ -385,8 +411,8 @@ def assign_issue(issue_key: str, assignee: str) -> str: def unassign_issue(issue_key: str) -> str: """Unassign a Jira issue.""" try: - issue = jira_client.issue(issue_key) - jira_client.assign_issue(issue, None) + issue = get_jira_client(get_http_headers()).issue(issue_key) + get_jira_client(get_http_headers()).assign_issue(issue, None) return f"Unassigned issue {issue_key}" except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to unassign issue {issue_key}: {e}") @@ -396,8 +422,8 @@ def unassign_issue(issue_key: str) -> str: def transition_issue(issue_key: str, transition_name: str, comment: str = None) -> str: """Transition a Jira issue to a new status.""" try: - issue = jira_client.issue(issue_key) - transitions = jira_client.transitions(issue) + issue = get_jira_client(get_http_headers()).issue(issue_key) + transitions = get_jira_client(get_http_headers()).transitions(issue) # Find the transition by name transition_id = None @@ -412,10 +438,12 @@ def transition_issue(issue_key: str, transition_name: str, comment: str = None) # Perform the transition if comment: - jira_client.transition_issue(issue, transition_id, comment=comment) + get_jira_client(get_http_headers()).transition_issue( + issue, transition_id, comment=comment + ) return f"Transitioned issue {issue_key} to '{transition_name}' with comment" else: - jira_client.transition_issue(issue, transition_id) + get_jira_client(get_http_headers()).transition_issue(issue, transition_id) return f"Transitioned issue {issue_key} to '{transition_name}'" except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to transition issue {issue_key}: {e}") @@ -425,8 +453,8 @@ def transition_issue(issue_key: str, transition_name: str, comment: str = None) def get_issue_transitions(issue_key: str) -> str: """Get available transitions for a Jira issue.""" try: - issue = jira_client.issue(issue_key) - transitions = jira_client.transitions(issue) + issue = get_jira_client(get_http_headers()).issue(issue_key) + transitions = get_jira_client(get_http_headers()).transitions(issue) transition_list = [{"id": t["id"], "name": t["name"]} for t in transitions] return to_markdown(transition_list) except Exception as e: @@ -439,7 +467,7 @@ def get_issue_transitions(issue_key: str) -> str: def delete_issue(issue_key: str) -> str: """Delete a Jira issue (use with caution).""" try: - issue = jira_client.issue(issue_key) + issue = get_jira_client(get_http_headers()).issue(issue_key) issue.delete() return f"Deleted issue {issue_key}" except Exception as e: @@ -450,7 +478,7 @@ def delete_issue(issue_key: str) -> str: def add_issue_labels(issue_key: str, labels: list) -> str: """Add labels to a Jira issue.""" try: - issue = jira_client.issue(issue_key) + issue = get_jira_client(get_http_headers()).issue(issue_key) current_labels = list(issue.fields.labels) new_labels = list(set(current_labels + labels)) # Remove duplicates issue.update(fields={"labels": new_labels}) @@ -463,7 +491,7 @@ def add_issue_labels(issue_key: str, labels: list) -> str: def remove_issue_labels(issue_key: str, labels: list) -> str: """Remove labels from a Jira issue.""" try: - issue = jira_client.issue(issue_key) + issue = get_jira_client(get_http_headers()).issue(issue_key) current_labels = list(issue.fields.labels) new_labels = [label for label in current_labels if label not in labels] issue.update(fields={"labels": new_labels}) @@ -528,6 +556,15 @@ def parse_arguments(): args = parse_arguments() if args.transport == "stdio": + if not all([JIRA_URL, JIRA_API_TOKEN]): + raise RuntimeError("Missing JIRA_URL or JIRA_API_TOKEN environment variables") mcp.run(transport=args.transport) else: + if not JIRA_URL: + raise RuntimeError("Missing JIRA_URL environment variable") + # If running as a server, we use the access token from the request, not from the environment variable. + # This is more secure because each caller providers their own access token instead of having one + # shared token that everyone who can access the server can use. + JIRA_API_TOKEN = None + jira_client = None mcp.run(transport=args.transport, host=args.host, port=args.port) diff --git a/test_server.py b/test_server.py index 540493d..d0117f5 100644 --- a/test_server.py +++ b/test_server.py @@ -2,9 +2,8 @@ import pytest import os -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import patch, MagicMock from fastapi import HTTPException -import json # Set up required environment variables before importing server module os.environ["JIRA_URL"] = "https://test.example.com" @@ -95,7 +94,9 @@ def __init__(self, comment_id, body="Test comment", author_name="Test Author"): @pytest.fixture def mock_jira_client(): """Create a mock Jira client""" - with patch("server.jira_client") as mock_client: + with patch("server.get_jira_client") as mock_get_client: + mock_client = MagicMock() + mock_get_client.return_value = mock_client yield mock_client