diff --git a/src/linux_mcp_server/commands.py b/src/linux_mcp_server/commands.py index 261d5a5c..3b2327a6 100644 --- a/src/linux_mcp_server/commands.py +++ b/src/linux_mcp_server/commands.py @@ -91,7 +91,9 @@ class CommandGroup(BaseModel): # === Services === "list_services": CommandGroup( commands={ - "default": CommandSpec(args=("systemctl", "list-units", "--type=service", "--all", "--no-pager")), + "default": CommandSpec( + args=("systemctl", "list-units", "--type=service", "--all", "--no-pager", "--output=json") + ), } ), "running_services": CommandGroup( @@ -103,12 +105,14 @@ class CommandGroup(BaseModel): ), "service_status": CommandGroup( commands={ - "default": CommandSpec(args=("systemctl", "status", "{service_name}", "--no-pager", "--full")), + "default": CommandSpec(args=("systemctl", "show", "{service_name}", "--no-pager")), } ), "service_logs": CommandGroup( commands={ - "default": CommandSpec(args=("journalctl", "-u", "{service_name}", "-n", "{lines}", "--no-pager")), + "default": CommandSpec( + args=("journalctl", "-u", "{service_name}", "-n", "{lines}", "--no-pager", "--output=json") + ), } ), # === Network === diff --git a/src/linux_mcp_server/parsers.py b/src/linux_mcp_server/parsers.py index 2db8607c..1a1b0f50 100644 --- a/src/linux_mcp_server/parsers.py +++ b/src/linux_mcp_server/parsers.py @@ -457,6 +457,23 @@ def parse_service_count(stdout: str) -> int: return count +def parse_systemctl_show(stdout: str) -> dict[str, str]: + """Parse systemctl show output into key-value pairs. + + Args: + stdout: Raw output from systemctl show command. + + Returns: + Dictionary of key-value pairs. + """ + result: dict[str, str] = {} + for line in stdout.strip().split("\n"): + if "=" in line: + key, value = line.split("=", 1) + result[key.strip()] = value.strip() + return result + + def parse_directory_listing( stdout: str, sort_by: str, diff --git a/src/linux_mcp_server/tools/services.py b/src/linux_mcp_server/tools/services.py index 50b4d11a..f7d297f1 100644 --- a/src/linux_mcp_server/tools/services.py +++ b/src/linux_mcp_server/tools/services.py @@ -1,16 +1,15 @@ """Service management tools.""" +import json import typing as t +from fastmcp.exceptions import ToolError from mcp.types import ToolAnnotations from pydantic import Field from linux_mcp_server.audit import log_tool_call from linux_mcp_server.commands import get_command -from linux_mcp_server.formatters import format_service_logs -from linux_mcp_server.formatters import format_service_status -from linux_mcp_server.formatters import format_services_list -from linux_mcp_server.parsers import parse_service_count +from linux_mcp_server.parsers import parse_systemctl_show from linux_mcp_server.server import mcp from linux_mcp_server.utils.decorators import disallow_local_execution_in_containers from linux_mcp_server.utils.types import Host @@ -27,27 +26,25 @@ @disallow_local_execution_in_containers async def list_services( host: Host = None, -) -> str: +) -> list[dict[str, str]]: """List all systemd services. Retrieves all systemd service units with their load state, active state, - sub-state, and description. Also includes a count of currently running services. + sub-state, and description. + + Returns: + list[dict[str, str]]: A list of dictionaries containing information about each service. + + Raises: + ToolError: If an error occurs while listing services. """ cmd = get_command("list_services") returncode, stdout, stderr = await cmd.run(host=host) if returncode != 0: - return f"Error listing services: {stderr}" - - # Get running services count - running_cmd = get_command("running_services") - returncode_summary, stdout_summary, _ = await running_cmd.run(host=host) + raise ToolError(f"Error listing services: {stderr}") - running_count = None - if returncode_summary == 0: - running_count = parse_service_count(stdout_summary) - - return format_services_list(stdout, running_count) + return t.cast(list[dict[str, str]], json.loads(stdout)) @mcp.tool( @@ -67,27 +64,34 @@ async def get_service_status( ), ], host: Host = None, -) -> str: +) -> dict[str, str]: """Get status of a specific systemd service. Retrieves detailed service information including active/enabled state, - main PID, memory usage, CPU time, and recent log entries from the journal. + main PID, memory usage, etc. + + Returns: + A dictionary containing the service status information. + + Raises: + ToolError: If there was an error getting the service status. """ # Ensure service name has .service suffix if not present if not service_name.endswith(".service") and "." not in service_name: service_name = f"{service_name}.service" cmd = get_command("service_status") - _, stdout, stderr = await cmd.run(host=host, service_name=service_name) + returncode, stdout, stderr = await cmd.run(host=host, service_name=service_name) - # Note: systemctl status returns non-zero for inactive services, but that's expected - if not stdout and stderr: - # Service not found - if "not found" in stderr.lower() or "could not be found" in stderr.lower(): - return f"Service '{service_name}' not found on this system." - return f"Error getting service status: {stderr}" + if returncode != 0: + raise ToolError(f"Error getting service status: {stderr}") + + status = parse_systemctl_show(stdout) + + if status.get("LoadState") == "not-found": + raise ToolError(f"Service '{service_name}' not found on this system.") - return format_service_status(stdout, service_name) + return status @mcp.tool( @@ -108,11 +112,17 @@ async def get_service_logs( ], lines: t.Annotated[int, Field(description="Number of log lines to retrieve.", ge=1, le=10_000)] = 50, host: Host = None, -) -> str: +) -> list[dict[str, str]]: """Get recent logs for a specific systemd service. Retrieves journal entries for the specified service unit, including timestamps, priority levels, and log messages. + + Raises: + ToolError: If an error occurs while retrieving logs. + + Returns: + list[dict[str, str]]: A list of dictionaries containing log entries. """ # Ensure service name has .service suffix if not present if not service_name.endswith(".service") and "." not in service_name: @@ -122,11 +132,9 @@ async def get_service_logs( returncode, stdout, stderr = await cmd.run(host=host, service_name=service_name, lines=lines) if returncode != 0: - if "not found" in stderr.lower() or "no entries" in stderr.lower(): - return f"No logs found for service '{service_name}'. The service may not exist or has no log entries." - return f"Error getting service logs: {stderr}" + raise ToolError(f"Error getting service logs: {stderr}") if is_empty_output(stdout): - return f"No log entries found for service '{service_name}'." + raise ToolError(f"No log entries found for service '{service_name}'.") - return format_service_logs(stdout, service_name, lines) + return t.cast(list[dict[str, str]], json.loads(stdout)) diff --git a/tests/parsers/test_parse_systemctl_show.py b/tests/parsers/test_parse_systemctl_show.py new file mode 100644 index 00000000..e63134e9 --- /dev/null +++ b/tests/parsers/test_parse_systemctl_show.py @@ -0,0 +1,33 @@ +"""Tests for parse_systemctl_show""" + +import pytest + +from linux_mcp_server.parsers import parse_systemctl_show + + +@pytest.mark.parametrize( + "stdout, expected", + [ + ( + """ + ActiveState=active + SubState=running + LoadState=loaded + """, + { + "ActiveState": "active", + "SubState": "running", + "LoadState": "loaded", + }, + ), + ( + """ + Field=value + EmptyField= + """, + {"Field": "value", "EmptyField": ""}, + ), + ], +) +def test_parse_systemctl_show(stdout, expected): + assert parse_systemctl_show(stdout) == expected diff --git a/tests/tools/test_services.py b/tests/tools/test_services.py index 35ffd086..c668b7d5 100644 --- a/tests/tools/test_services.py +++ b/tests/tools/test_services.py @@ -4,6 +4,8 @@ import pytest +from fastmcp.exceptions import ToolError + @pytest.fixture def mock_execute_with_fallback(mock_execute_with_fallback_for): @@ -23,54 +25,69 @@ async def test_list_services(self, mcp_client): assert all(any(n in result_text for n in case) for case in expected), "Did not find all expected values" - @pytest.mark.parametrize( - "service_name, expected", - ( - ("sshd.service", ("active", "inactive", "loaded", "not found")), - ("nonexistent-service-xyz123", ("not found", "could not", "error")), - ), - ) - async def test_get_service_status(self, mcp_client, service_name, expected): - result = await mcp_client.call_tool("get_service_status", arguments={"service_name": service_name}) - result_text = result.content[0].text.casefold() + async def test_get_service_status(self, mcp_client): + """Test getting service status returns structured data.""" + result = await mcp_client.call_tool("get_service_status", arguments={"service_name": "sshd.service"}) + + # The tool returns structured data (dict), so check the structured content + assert result.structured_content is not None, "Expected structured data" + + # Verify we have service status fields + data = result.structured_content + assert "LoadState" in data or "ActiveState" in data, "Expected service status fields in structured data" - assert any(n in result_text for n in expected), "Did not find any expected values" + # ActiveState should be one of the valid states + if "ActiveState" in data: + assert data["ActiveState"] in ("active", "inactive", "activating", "deactivating", "failed"), ( + f"Unexpected ActiveState: {data['ActiveState']}" + ) async def test_get_service_status_with_nonexistent_service(self, mcp_client): - result = await mcp_client.call_tool( - "get_service_status", arguments={"service_name": "nonexistent-service-xyz123"} - ) - result_text = result.content[0].text.casefold() - expected = ( - "not found", - "could not", - "error", - ) - assert any(n in result_text for n in expected), "Did not find any expected values" + """Test that nonexistent service raises ToolError.""" + + with pytest.raises(ToolError, match="Service 'nonexistent-service-xyz123.service' not found on this system."): + await mcp_client.call_tool("get_service_status", arguments={"service_name": "nonexistent-service-xyz123"}) + + async def test_get_service_status_error(self, mock_execute_with_fallback, mcp_client): + """Test that get_service_status raises ToolError when systemctl fails.""" + mock_execute_with_fallback.return_value = (1, "", "Failed to get unit file state: Connection refused") + + with pytest.raises(ToolError, match="Error getting service status: Failed to get unit file state"): + await mcp_client.call_tool("get_service_status", arguments={"service_name": "sshd"}) + + async def test_get_service_logs(self, mock_execute_with_fallback, mcp_client): + """Test getting service logs with mocked output.""" + mock_output = '[{"__REALTIME_TIMESTAMP": "1600000000000000", "MESSAGE": "sshd: session opened for user test", "PRIORITY": "6", "_PID": "1234"}, {"__REALTIME_TIMESTAMP": "1600000001000000", "MESSAGE": "sshd: session closed for user test", "PRIORITY": "6", "_PID": "1234"}]' + mock_execute_with_fallback.return_value = (0, mock_output, "") - async def test_get_service_logs(self, mcp_client): result = await mcp_client.call_tool("get_service_logs", arguments={"service_name": "sshd.service", "lines": 5}) - # Filter out empty lines, header lines (=), and journalctl boot markers (--) - result_lines = [ - line - for line in result.content[0].text.split("\n") - if line and not line.startswith("=") and not line.startswith("--") - ] - assert len(result_lines) <= 5, "Got more lines than expected" + assert result.structured_content is not None, "Expected structured data" + logs = result.structured_content["result"] + assert isinstance(logs, list) + assert len(logs) == 2 + assert logs[0]["MESSAGE"] == "sshd: session opened for user test" + mock_execute_with_fallback.assert_called() async def test_get_service_logs_with_nonexistent_service(self, mcp_client): - result = await mcp_client.call_tool( - "get_service_logs", arguments={"service_name": "nonexistent-service-xyz123", "lines": 10} - ) - result_text = result.content[0].text.casefold() - expected = ( - "not found", - "no entries", - "error", - ) + with pytest.raises(ToolError, match="No log entries found for service 'nonexistent-service-xyz123.service'."): + await mcp_client.call_tool( + "get_service_logs", arguments={"service_name": "nonexistent-service-xyz123", "lines": 10} + ) + + async def test_get_service_logs_error(self, mock_execute_with_fallback, mcp_client): + """Test that get_service_logs raises ToolError when journalctl fails.""" + mock_execute_with_fallback.return_value = (1, "", "Failed to access journal: Permission denied") - assert any(n in result_text for n in expected), "Did not find any expected values" + with pytest.raises(ToolError, match="Error getting service logs: Failed to access journal"): + await mcp_client.call_tool("get_service_logs", arguments={"service_name": "sshd", "lines": 10}) + + async def test_list_services_error(self, mock_execute_with_fallback, mcp_client): + """Test that list_services raises ToolError when systemctl fails.""" + mock_execute_with_fallback.return_value = (1, "", "Failed to connect to bus: No such file or directory") + + with pytest.raises(ToolError, match="Error listing services: Failed to connect to bus"): + await mcp_client.call_tool("list_services") class TestRemoteServices: @@ -78,40 +95,54 @@ class TestRemoteServices: async def test_list_services_remote(self, mock_execute_with_fallback, mcp_client): """Test listing services on a remote host.""" - mock_output = "UNIT LOAD ACTIVE SUB DESCRIPTION\nnginx.service loaded active running Nginx server\n" + mock_output = '[{"unit":"nginx.service","load":"loaded","active":"active","sub":"running","description":"Nginx HTTP Server"}]' mock_execute_with_fallback.return_value = (0, mock_output, "") result = await mcp_client.call_tool("list_services", arguments={"host": "remote.example.com"}) - result_text = result.content[0].text.casefold() - assert "nginx.service" in result_text - assert "system services" in result_text + assert result.structured_content is not None + services = result.structured_content["result"] + assert isinstance(services, list) + assert services[0]["unit"] == "nginx.service" + assert services[0]["load"] == "loaded" + assert services[0]["active"] == "active" + assert services[0]["sub"] == "running" + assert services[0]["description"] == "Nginx HTTP Server" mock_execute_with_fallback.assert_called() async def test_get_service_status_remote(self, mock_execute_with_fallback, mcp_client): """Test getting service status on a remote host.""" - mock_output = "● nginx.service - Nginx HTTP Server\n Loaded: loaded\n Active: active (running)" + # Mock systemctl show output (key=value format) + mock_output = "LoadState=loaded\nActiveState=active\nSubState=running\nDescription=Nginx HTTP Server" mock_execute_with_fallback.return_value = (0, mock_output, "") result = await mcp_client.call_tool( "get_service_status", arguments={"service_name": "nginx", "host": "remote.example.com"} ) - result_text = result.content[0].text.casefold() - assert "nginx.service" in result_text - assert "active" in result_text + # The tool now returns structured data + assert result.structured_content is not None + data = result.structured_content + + assert data.get("LoadState") == "loaded" + assert data.get("ActiveState") == "active" + assert data.get("SubState") == "running" mock_execute_with_fallback.assert_called() async def test_get_service_logs_remote(self, mock_execute_with_fallback, mcp_client): """Test getting service logs on a remote host.""" - mock_output = "Jan 01 12:00:00 host nginx[1234]: Starting Nginx\nJan 01 12:00:01 host nginx[1234]: Started" + mock_output = '[{"_REALTIME_TIMESTAMP": "Jan 01 12:00:00", "_PID": "1234", "MESSAGE": "Starting Nginx"}, {"_REALTIME_TIMESTAMP": "Jan 01 12:00:01", "_PID": "1234", "MESSAGE": "Started"}]' mock_execute_with_fallback.return_value = (0, mock_output, "") result = await mcp_client.call_tool( "get_service_logs", arguments={"service_name": "nginx", "host": "remote.example.com", "lines": 50} ) - result_text = result.content[0].text.casefold() - assert "nginx" in result_text - assert "starting" in result_text + assert result.structured_content is not None + data = result.structured_content["result"] + + assert isinstance(data, list) + assert len(data) == 2 + assert data[0]["MESSAGE"] == "Starting Nginx" + assert data[1]["MESSAGE"] == "Started" mock_execute_with_fallback.assert_called()