Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 164 additions & 6 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt::Debug;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
Expand Down Expand Up @@ -55,6 +57,7 @@ use mcp_types::ReadResourceResult;
use mcp_types::RequestId;
use serde_json;
use serde_json::Value;
use tokio::fs;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::sync::oneshot;
Expand Down Expand Up @@ -98,6 +101,9 @@ use crate::protocol::ReviewDecision;
use crate::protocol::SandboxCommandAssessment;
use crate::protocol::SandboxPolicy;
use crate::protocol::SessionConfiguredEvent;
use crate::protocol::SkillErrorInfo;
use crate::protocol::SkillInfo;
use crate::protocol::SkillLoadOutcomeInfo;
use crate::protocol::StreamErrorEvent;
use crate::protocol::Submission;
use crate::protocol::TokenCountEvent;
Expand All @@ -110,6 +116,9 @@ use crate::rollout::RolloutRecorderParams;
use crate::rollout::map_session_init_error;
use crate::shell;
use crate::shell_snapshot::ShellSnapshot;
use crate::skills::SkillLoadOutcome;
use crate::skills::SkillMetadata;
use crate::skills::load_skills;
use crate::state::ActiveTurn;
use crate::state::SessionServices;
use crate::state::SessionState;
Expand All @@ -126,6 +135,7 @@ use crate::tools::spec::ToolsConfigParams;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::DeveloperInstructions;
use crate::user_instructions::SkillInstructions;
use crate::user_instructions::UserInstructions;
use crate::user_notification::UserNotification;
use crate::util::backoff;
Expand Down Expand Up @@ -174,7 +184,34 @@ impl Codex {
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
let (tx_event, rx_event) = async_channel::unbounded();

let user_instructions = get_user_instructions(&config).await;
let loaded_skills = if config.features.enabled(Feature::Skills) {
Some(load_skills(&config))
} else {
None
};

if let Some(outcome) = &loaded_skills {
for err in &outcome.errors {
error!(
"failed to load skill {}: {}",
err.path.display(),
err.message
);
}
}

let skills_outcome = loaded_skills.clone();

let user_instructions = get_user_instructions(
&config,
skills_outcome.as_ref().and_then(|outcome| {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why aren't we passin outcome.skills directly?

outcome
.errors
.is_empty()
.then_some(outcome.skills.as_slice())
}),
)
.await;

let exec_policy = load_exec_policy_for_features(&config.features, &config.codex_home)
.await
Expand Down Expand Up @@ -202,6 +239,7 @@ impl Codex {

// Generate a unique ID for the lifetime of this Codex session.
let session_source_clone = session_configuration.session_source.clone();

let session = Session::new(
session_configuration,
config.clone(),
Expand All @@ -210,6 +248,7 @@ impl Codex {
tx_event.clone(),
conversation_history,
session_source_clone,
skills_outcome.clone(),
)
.await
.map_err(|e| {
Expand Down Expand Up @@ -466,6 +505,7 @@ impl Session {
}
}

#[allow(clippy::too_many_arguments)]
async fn new(
session_configuration: SessionConfiguration,
config: Arc<Config>,
Expand All @@ -474,6 +514,7 @@ impl Session {
tx_event: Sender<Event>,
initial_history: InitialHistory,
session_source: SessionSource,
skills: Option<SkillLoadOutcome>,
) -> anyhow::Result<Arc<Self>> {
debug!(
"Configuring session: model={}; provider={:?}",
Expand Down Expand Up @@ -580,7 +621,7 @@ impl Session {
.await
.map(Arc::new);
}
let state = SessionState::new(session_configuration.clone());
let state = SessionState::new(session_configuration.clone(), skills.clone());

let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
Expand Down Expand Up @@ -609,6 +650,7 @@ impl Session {
// Dispatch the SessionConfiguredEvent first and then report any errors.
// If resuming, include converted initial messages in the payload so UIs can render them immediately.
let initial_messages = initial_history.get_event_msgs();
let skill_load_outcome = skill_load_outcome_for_client(skills.as_ref());

let events = std::iter::once(Event {
id: INITIAL_SUBMIT_ID.to_owned(),
Expand All @@ -623,6 +665,7 @@ impl Session {
history_log_id,
history_entry_count,
initial_messages,
skill_load_outcome,
rollout_path,
}),
})
Expand Down Expand Up @@ -1302,6 +1345,54 @@ impl Session {
}
}

async fn inject_skills(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if re-parsing user message in core is the right approach.

It might be better to pass selected skills as one of UserInputs or as an explicit parameter on UserTurn similar to how we handle images.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other problem with the current approach is that new skills added during the session won't be mentionable. And the contents of the skills won't be reloaded.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we please mode this and related skill helpers into skill.rs?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also take session as a reference instead of having an instance method on it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to keep the size of codex.rs sane

&self,
turn_context: &TurnContext,
user_input: &[UserInput],
) -> Vec<ResponseItem> {
if user_input.is_empty() || !self.enabled(Feature::Skills) {
return Vec::new();
}

let skills = {
let state = self.state.lock().await;
state
.skills
.as_ref()
.map(|outcome| outcome.skills.clone())
.unwrap_or_default()
};

let mentioned_skills = collect_explicit_skill_mentions(user_input, &skills);
if mentioned_skills.is_empty() {
return Vec::new();
}

let mut injections: Vec<ResponseItem> = Vec::with_capacity(mentioned_skills.len());
for skill in mentioned_skills {
match fs::read_to_string(&skill.path).await {
Ok(contents) => {
injections.push(ResponseItem::from(SkillInstructions {
name: skill.name,
path: skill.path.to_string_lossy().into_owned(),
contents,
}));
}
Err(err) => {
let message = format!(
"Failed to load skill {} at {}: {err:#}",
skill.name,
skill.path.display()
);
self.send_event(turn_context, EventMsg::Warning(WarningEvent { message }))
.await;
}
}
}

injections
}

pub(crate) async fn notify_background_event(
&self,
turn_context: &TurnContext,
Expand Down Expand Up @@ -2017,6 +2108,66 @@ async fn spawn_review_thread(
.await;
}

fn collect_explicit_skill_mentions(
inputs: &[UserInput],
skills: &[SkillMetadata],
) -> Vec<SkillMetadata> {
let mut selected: Vec<SkillMetadata> = Vec::new();
let mut seen: HashSet<String> = HashSet::new();

for input in inputs {
if let UserInput::Skill { name, path } = input
&& seen.insert(name.clone())
&& let Some(skill) = skills
.iter()
.find(|s| s.name == *name && paths_match(&s.path, path))
{
selected.push(skill.clone());
}
}

selected
}

fn paths_match(a: &Path, b: &Path) -> bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this check?

if a == b {
return true;
}

let Ok(ca) = std::fs::canonicalize(a) else {
return false;
};
let Ok(cb) = std::fs::canonicalize(b) else {
return false;
};

ca == cb
}

fn skill_load_outcome_for_client(
outcome: Option<&SkillLoadOutcome>,
) -> Option<SkillLoadOutcomeInfo> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we can move SkillLoadOutcome to the protocol directly and avoid double copies?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the protocol type inside core would couple the loader/cache to the external serde/TS model. The current split keeps boundaries clean, and the mapping is trivial. wdyt?

outcome.map(|outcome| SkillLoadOutcomeInfo {
skills: outcome
.skills
.iter()
.map(|skill| SkillInfo {
name: skill.name.clone(),
description: skill.description.clone(),
path: skill.path.clone(),
})
.collect(),
errors: outcome
.errors
.iter()
.map(|err| SkillErrorInfo {
path: err.path.clone(),
message: err.message.clone(),
})
.collect(),
})
}

/// Takes a user message as input and runs a loop where, at each turn, the model
/// replies with either:
///
Expand Down Expand Up @@ -2045,11 +2196,18 @@ pub(crate) async fn run_task(
});
sess.send_event(&turn_context, event).await;

let skill_injections = sess.inject_skills(&turn_context, &input).await;

let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
let response_item: ResponseItem = initial_input_for_turn.clone().into();
sess.record_response_item_and_emit_turn_item(turn_context.as_ref(), response_item)
.await;

if !skill_injections.is_empty() {
sess.record_conversation_items(&turn_context, &skill_injections)
.await;
}

sess.maybe_start_ghost_snapshot(Arc::clone(&turn_context), cancellation_token.child_token())
.await;
let mut last_agent_message: Option<String> = None;
Expand Down Expand Up @@ -2603,7 +2761,7 @@ mod tests {
session_source: SessionSource::Exec,
};

let mut state = SessionState::new(session_configuration);
let mut state = SessionState::new(session_configuration, None);
let initial = RateLimitSnapshot {
primary: Some(RateLimitWindow {
used_percent: 10.0,
Expand Down Expand Up @@ -2674,7 +2832,7 @@ mod tests {
session_source: SessionSource::Exec,
};

let mut state = SessionState::new(session_configuration);
let mut state = SessionState::new(session_configuration, None);
let initial = RateLimitSnapshot {
primary: Some(RateLimitWindow {
used_percent: 15.0,
Expand Down Expand Up @@ -2880,7 +3038,7 @@ mod tests {
let otel_event_manager =
otel_event_manager(conversation_id, config.as_ref(), &model_family);

let state = SessionState::new(session_configuration.clone());
let state = SessionState::new(session_configuration.clone(), None);

let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
Expand Down Expand Up @@ -2962,7 +3120,7 @@ mod tests {
let otel_event_manager =
otel_event_manager(conversation_id, config.as_ref(), &model_family);

let state = SessionState::new(session_configuration.clone());
let state = SessionState::new(session_configuration.clone(), None);

let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
Expand Down
29 changes: 20 additions & 9 deletions codex-rs/core/src/event_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use codex_protocol::user_input::UserInput;
use tracing::warn;
use uuid::Uuid;

use crate::user_instructions::SkillInstructions;
use crate::user_instructions::UserInstructions;
use crate::user_shell_command::is_user_shell_command_text;

Expand All @@ -23,7 +24,9 @@ fn is_session_prefix(text: &str) -> bool {
}

fn parse_user_message(message: &[ContentItem]) -> Option<UserMessageItem> {
if UserInstructions::is_user_instructions(message) {
if UserInstructions::is_user_instructions(message)
|| SkillInstructions::is_skill_instructions(message)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs a test. see below in this file

{
return None;
}

Expand Down Expand Up @@ -198,14 +201,22 @@ mod tests {
text: "# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>".to_string(),
}],
},
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<user_shell_command>echo 42</user_shell_command>".to_string(),
}],
},
];
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<SKILL name=\"demo\" path=\"skills/demo/SKILL.md\">\nbody\n</SKILL>"
.to_string(),
}],
},
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<user_shell_command>echo 42</user_shell_command>".to_string(),
}],
},
];

for item in items {
let turn_item = parse_turn_item(&item);
Expand Down
Loading
Loading