Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions iam-policy-autopilot-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,15 @@ for direct integration with IDEs and tools. 'http' starts an HTTP server for net
long_help = "Port number to bind the HTTP server to when using HTTP transport. \
Only used when --transport=http. The server will bind to 127.0.0.1 (localhost) on the specified port.")]
port: u16,

/// Run in read-only mode (no policy modifications will be applied to any AWS account)
#[arg(
long = "read-only",
long_help = "When enabled, the MCP server will operate in read-only mode and will not make any \
modifications to AWS resources. Tools that would normally apply changes will instead return the generated \
output without executing any mutations."
)]
read_only: bool,
},
}

Expand Down Expand Up @@ -602,8 +611,12 @@ async fn main() {
}
}

Commands::McpServer { transport, port } => {
match start_mcp_server(transport, port).await {
Commands::McpServer {
transport,
port,
read_only,
} => {
match start_mcp_server(transport, port, read_only).await {
Ok(()) => ExitCode::Success,
Err(e) => {
print_cli_command_error(e);
Expand Down
6 changes: 3 additions & 3 deletions iam-policy-autopilot-mcp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl Display for McpTransport {
}
}

pub async fn start_mcp_server(transport: McpTransport, port: u16) -> Result<()> {
pub async fn start_mcp_server(transport: McpTransport, port: u16, read_only: bool) -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of passing individual booleans, let's introduce a configuration object.

info!("Starting MCP server with transport: {}", transport);

let env = env_logger::Env::default().filter_or("IAMPA_LOG_LEVEL", "debug");
Expand Down Expand Up @@ -65,14 +65,14 @@ pub async fn start_mcp_server(transport: McpTransport, port: u16) -> Result<()>
let bind_address: String = format!("{}:{}", BIND_ADDRESS, port);
info!("Starting HTTP MCP server at {}", bind_address);

crate::mcp::begin_http_transport(bind_address.as_str(), path_str)
crate::mcp::begin_http_transport(bind_address.as_str(), path_str, read_only)
.await
.with_context(|| format!("Failed to start HTTP Server at '{bind_address}'"))?
}
McpTransport::Stdio => {
info!("Starting STDIO MCP server");

crate::mcp::begin_stdio_transport(path_str)
crate::mcp::begin_stdio_transport(path_str, read_only)
.await
.with_context(|| "Failed to start STDIO Server".to_string())?
}
Expand Down
60 changes: 53 additions & 7 deletions iam-policy-autopilot-mcp-server/src/mcp.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use anyhow;
use log::{error, info, trace};
use rmcp::{
handler::server::tool::ToolCallContext,
handler::server::{tool::ToolRouter, wrapper::Parameters},
model::{ErrorCode, ServerCapabilities, ServerInfo},
model::{
CallToolRequestParam, CallToolResult, ErrorCode, ListToolsResult, PaginatedRequestParam,
ServerCapabilities, ServerInfo, Tool,
},
service::RequestContext,
tool, tool_handler, tool_router,
tool, tool_router,
transport::{
self, streamable_http_server::session::local::LocalSessionManager, StreamableHttpService,
},
Expand All @@ -22,14 +26,16 @@ use crate::tools::{
struct IamAutoPilotMcpServer {
tool_router: ToolRouter<Self>,
log_file: Option<String>,
read_only: bool,
}

#[tool_router]
impl IamAutoPilotMcpServer {
pub fn new(log_file: Option<String>) -> Self {
pub fn new(log_file: Option<String>, read_only: bool) -> Self {
Self {
tool_router: Self::tool_router(),
log_file,
read_only,
}
}

Expand Down Expand Up @@ -137,7 +143,6 @@ impl IamAutoPilotMcpServer {
}
}

#[tool_handler]
impl ServerHandler for IamAutoPilotMcpServer {
fn get_info(&self) -> ServerInfo {
ServerInfo {
Expand Down Expand Up @@ -166,14 +171,52 @@ impl ServerHandler for IamAutoPilotMcpServer {
..Default::default()
}
}

async fn list_tools(
&self,
_: Option<PaginatedRequestParam>,
_: RequestContext<RoleServer>,
) -> Result<ListToolsResult, McpError> {
let all_tools = self.tool_router.list_all();

// Filter out fix_access_denied tool when in read-only mode
let tools: Vec<Tool> = if self.read_only {
all_tools
.into_iter()
.filter(|tool| tool.name.as_ref() != "fix_access_denied")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of filtering by name, would it make sense for the tools to have a property against we filter?

.collect()
} else {
all_tools
};

Ok(ListToolsResult {
tools,
next_cursor: None,
})
}

async fn call_tool(
&self,
request: CallToolRequestParam,
context: RequestContext<RoleServer>,
) -> Result<CallToolResult, McpError> {
let tool_context = ToolCallContext {
service: self,
name: request.name.clone(),
arguments: request.arguments,
request_context: context,
};
self.tool_router.call(tool_context).await
}
}

pub async fn begin_http_transport(
bind_address: &str,
log_file: Option<String>,
read_only: bool,
) -> anyhow::Result<()> {
let service = StreamableHttpService::new(
move || Ok(IamAutoPilotMcpServer::new(log_file.clone())),
move || Ok(IamAutoPilotMcpServer::new(log_file.clone(), read_only)),
LocalSessionManager::default().into(),
Default::default(),
);
Expand Down Expand Up @@ -205,8 +248,11 @@ pub async fn begin_http_transport(
Ok(())
}

pub async fn begin_stdio_transport(log_file: Option<String>) -> anyhow::Result<()> {
let server = IamAutoPilotMcpServer::new(log_file);
pub async fn begin_stdio_transport(
log_file: Option<String>,
read_only: bool,
) -> anyhow::Result<()> {
let server = IamAutoPilotMcpServer::new(log_file, read_only);
let service = server.serve(transport::stdio()).await?;
service.waiting().await?;
Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ use tokio::net::TcpStream;
use tokio::process::{Child, Command};
use tokio::time::{sleep, Duration};

async fn setup_stdio() -> RunningService<RoleClient, ()> {
async fn setup_stdio(flags: Vec<&str>) -> RunningService<RoleClient, ()> {
// Create MCP client using TokioChildProcess with debug binary
let mut command = Command::new("../target/debug/iam-policy-autopilot");
command.args(&["mcp-server"]);
let mut args = vec!["mcp-server"];
args.extend(flags);
command.args(&args);

().serve(
TokioChildProcess::new(command)
Expand Down Expand Up @@ -96,7 +98,7 @@ async fn setup_http() -> (RunningService<RoleClient, InitializeRequestParam>, Ch

#[tokio::test]
async fn test_stdio_list_tools() {
let client = setup_stdio().await;
let client = setup_stdio(vec![]).await;

// Call list_tools to get available tools
let tools_result = client.list_tools(None).await.unwrap();
Expand Down Expand Up @@ -130,7 +132,7 @@ async fn test_stdio_generate_policy() {
.unwrap()
.join(Path::new("tests/test_data/lambda.py"));

let client = setup_stdio().await;
let client = setup_stdio(vec![]).await;
let tool_result = client
.call_tool(CallToolRequestParam {
name: "generate_application_policies".into(),
Expand All @@ -150,7 +152,7 @@ async fn test_stdio_generate_policy() {

#[tokio::test]
async fn test_stdio_generate_policy_for_access_denied() {
let client = setup_stdio().await;
let client = setup_stdio(vec![]).await;
let tool_result = client
.call_tool(CallToolRequestParam {
name: "generate_policy_for_access_denied".into(),
Expand Down Expand Up @@ -252,3 +254,28 @@ async fn test_http_generate_policy_for_access_denied() {
let _ = server_process.start_kill();
let _ = server_process.wait().await;
}

#[tokio::test]
async fn test_stdio_read_only_mode_hides_fix_access_denied_tool() {
let client = setup_stdio(vec!["--read-only"]).await;

// Call list_tools to get available tools
let tools_result = client.list_tools(None).await.unwrap();

// In read-only mode, we should only have 2 tools (fix_access_denied should be hidden)
assert_eq!(
tools_result.tools.len(),
2,
"Expected 2 tools in read-only mode, got {}",
tools_result.tools.len()
);

// Check that fix_access_denied is NOT present
let tool_names: Vec<&str> = tools_result.tools.iter().map(|t| t.name.as_ref()).collect();
assert!(tool_names.contains(&"generate_application_policies"));
assert!(tool_names.contains(&"generate_policy_for_access_denied"));
assert!(
!tool_names.contains(&"fix_access_denied"),
"fix_access_denied tool should be hidden in read-only mode"
);
}