Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,6 @@ pub trait API: Sync + Send {
/// Retrieves the provider configuration for the default agent
async fn get_default_provider(&self) -> anyhow::Result<Provider<Url>>;

/// Sets the default provider for all the agents
async fn set_default_provider(&self, provider_id: ProviderId) -> anyhow::Result<()>;
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop the setters and getters for default model also from api. Update the implementation to use only the model+provider selector to select them atomically.


/// Updates the caller's default provider and model together, ensuring all
/// commands resolve a consistent pair without requiring a follow-up model
/// selection call.
Expand Down
7 changes: 0 additions & 7 deletions crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,6 @@ impl<A: Services, F: CommandInfra + EnvironmentInfra + SkillRepository + GrpcInf
agent_provider_resolver.get_provider(Some(agent_id)).await
}

async fn set_default_provider(&self, provider_id: ProviderId) -> anyhow::Result<()> {
let result = self.services.set_default_provider(provider_id).await;
// Invalidate cache for agents
let _ = self.services.reload_agents().await;
result
}

async fn user_info(&self) -> Result<Option<User>> {
let provider = self.get_default_provider().await?;
if let Some(api_key) = provider.api_key() {
Expand Down
4 changes: 0 additions & 4 deletions crates/forge_app/src/command_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,6 @@ mod tests {
Ok(ProviderId::OPENAI)
}

async fn set_default_provider(&self, _provider_id: ProviderId) -> Result<()> {
Ok(())
}

async fn get_provider_model(
&self,
_provider_id: Option<&ProviderId>,
Expand Down
6 changes: 4 additions & 2 deletions crates/forge_app/src/git_app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,10 @@ impl<S: Services> GitApp<S> {
// Resolve provider and model: commit config takes priority over agent defaults.
// If the configured provider is unavailable (e.g. logged out), fall back to the
// agent's provider/model with a warning.
let (provider, model) = match commit_config.and_then(|c| c.provider.zip(c.model)) {
Some((provider_id, commit_model)) => {
let (provider, model) = match commit_config {
Some(c) => {
let provider_id = c.provider;
let commit_model = c.model;
match self.services.get_provider(provider_id).await {
Ok(provider) => {
match self.services.refresh_provider_credential(provider).await {
Expand Down
15 changes: 0 additions & 15 deletions crates/forge_app/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,6 @@ pub trait AppConfigService: Send + Sync {
/// Gets the user's default provider ID.
async fn get_default_provider(&self) -> anyhow::Result<ProviderId>;

/// Sets the user's default provider preference.
async fn set_default_provider(
&self,
provider_id: forge_domain::ProviderId,
) -> anyhow::Result<()>;

/// Gets the user's default model for a specific provider or the currently
/// active provider. When provider_id is None, uses the currently active
/// provider.
Expand Down Expand Up @@ -978,15 +972,6 @@ impl<I: Services> AppConfigService for I {
self.config_service().get_default_provider().await
}

async fn set_default_provider(
&self,
provider_id: forge_domain::ProviderId,
) -> anyhow::Result<()> {
self.config_service()
.set_default_provider(provider_id)
.await
}

async fn get_provider_model(
&self,
provider_id: Option<&forge_domain::ProviderId>,
Expand Down
23 changes: 14 additions & 9 deletions crates/forge_config/src/legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,23 @@ impl LegacyConfig {
/// Converts a [`LegacyConfig`] into the fields of [`ForgeConfig`] that it
/// covers, leaving all other fields at their defaults (`None`).
fn into_forge_config(self) -> ForgeConfig {
let session = self.provider.as_deref().map(|provider_id| {
let model_id = self.model.get(provider_id).cloned();
ModelConfig { provider_id: Some(provider_id.to_string()), model_id }
let session = self.provider.as_deref().and_then(|provider_id| {
self.model
.get(provider_id)
.map(|model_id| ModelConfig::new(provider_id, model_id.as_str()))
});

let commit = self
.commit
.map(|c| ModelConfig { provider_id: c.provider, model_id: c.model });
let commit = self.commit.and_then(|c| {
c.provider
.zip(c.model)
.map(|(pid, mid)| ModelConfig::new(pid, mid))
});

let suggest = self
.suggest
.map(|s| ModelConfig { provider_id: s.provider, model_id: s.model });
let suggest = self.suggest.and_then(|s| {
s.provider
.zip(s.model)
.map(|(pid, mid)| ModelConfig::new(pid, mid))
});

ForgeConfig { session, commit, suggest, ..Default::default() }
}
Expand Down
17 changes: 11 additions & 6 deletions crates/forge_config/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@ pub type ProviderId = String;
pub type ModelId = String;

/// Pairs a provider and model together for a specific operation.
#[derive(
Default, Debug, Setters, Clone, PartialEq, Serialize, Deserialize, JsonSchema, fake::Dummy,
)]
#[setters(strip_option, into)]
#[derive(Debug, Setters, Clone, PartialEq, Serialize, Deserialize, JsonSchema, fake::Dummy)]
#[setters(into)]
pub struct ModelConfig {
/// The provider to use for this operation.
pub provider_id: Option<String>,
pub provider_id: String,
/// The model to use for this operation.
pub model_id: Option<String>,
pub model_id: String,
}

impl ModelConfig {
/// Creates a new [`ModelConfig`] with the given provider and model IDs.
pub fn new(provider_id: impl Into<String>, model_id: impl Into<String>) -> Self {
Self { provider_id: provider_id.into(), model_id: model_id.into() }
}
}
15 changes: 3 additions & 12 deletions crates/forge_config/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ mod tests {
// carries session/commit/suggest (all other fields are None) and layer
// it on top of the embedded defaults. The default values must survive.
let legacy = ForgeConfig {
session: Some(ModelConfig {
provider_id: Some("anthropic".to_string()),
model_id: Some("claude-3".to_string()),
}),
session: Some(ModelConfig::new("anthropic", "claude-3")),
..Default::default()
};
let legacy_toml = toml_edit::ser::to_string_pretty(&legacy).unwrap();
Expand All @@ -215,10 +212,7 @@ mod tests {
// Session should come from the legacy layer
assert_eq!(
actual.session,
Some(ModelConfig {
provider_id: Some("anthropic".to_string()),
model_id: Some("claude-3".to_string()),
})
Some(ModelConfig::new("anthropic", "claude-3"))
);

// Default values from .forge.toml must be retained, not reset to zero
Expand All @@ -242,10 +236,7 @@ mod tests {
.build()
.unwrap();

let expected = Some(ModelConfig {
provider_id: Some("fake-provider".to_string()),
model_id: Some("fake-model".to_string()),
});
let expected = Some(ModelConfig::new("fake-provider", "fake-model"));
assert_eq!(actual.session, expected);
}
}
23 changes: 11 additions & 12 deletions crates/forge_domain/src/commit_config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use derive_setters::Setters;
use merge::Merge;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

Expand All @@ -11,19 +10,19 @@ use crate::{ModelId, ProviderId};
/// generation, instead of using the active agent's provider and model. This is
/// useful when you want to use a cheaper or faster model for simple commit
/// message generation.
#[derive(Default, Debug, Clone, Serialize, Deserialize, Merge, Setters, JsonSchema, PartialEq)]
#[setters(strip_option, into)]
#[derive(Debug, Clone, Serialize, Deserialize, Setters, JsonSchema, PartialEq)]
#[setters(into)]
pub struct CommitConfig {
/// Provider ID to use for commit message generation.
/// If not specified, the active agent's provider will be used.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[merge(strategy = crate::merge::option)]
pub provider: Option<ProviderId>,
pub provider: ProviderId,

/// Model ID to use for commit message generation.
/// If not specified, the provider's default model or the active agent's
/// model will be used.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[merge(strategy = crate::merge::option)]
pub model: Option<ModelId>,
pub model: ModelId,
}

impl CommitConfig {
/// Creates a new [`CommitConfig`] with the given provider and model.
pub fn new(provider: impl Into<ProviderId>, model: impl Into<ModelId>) -> Self {
Self { provider: provider.into(), model: model.into() }
}
}
19 changes: 12 additions & 7 deletions crates/forge_domain/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@ use crate::{Effort, ModelId, ProviderId};
///
/// Used to represent an active session, decoupled from the on-disk
/// configuration format.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, Setters)]
#[setters(strip_option, into)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Setters)]
#[setters(into)]
pub struct SessionConfig {
/// The active provider ID (e.g. `"anthropic"`).
pub provider_id: Option<String>,
pub provider_id: String,
/// The model ID to use with this provider.
pub model_id: Option<String>,
pub model_id: String,
}

impl SessionConfig {
/// Creates a new [`SessionConfig`] with the given provider and model IDs.
pub fn new(provider_id: impl Into<String>, model_id: impl Into<String>) -> Self {
Self { provider_id: provider_id.into(), model_id: model_id.into() }
}
}

/// All discrete mutations that can be applied to the application configuration.
Expand All @@ -27,9 +34,7 @@ pub struct SessionConfig {
/// each in order, and persist the result atomically.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigOperation {
/// Set the active provider.
SetProvider(ProviderId),
/// Set the model for the given provider.
/// Set the model for the given provider, replacing any existing session.
SetModel(ProviderId, ModelId),
/// Set the commit-message generation configuration.
SetCommitConfig(crate::CommitConfig),
Expand Down
71 changes: 23 additions & 48 deletions crates/forge_infra/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,36 +32,17 @@ pub fn to_environment(cwd: PathBuf) -> Environment {
/// persisted config without an intermediate `Environment` round-trip.
fn apply_config_op(fc: &mut ForgeConfig, op: ConfigOperation) {
match op {
ConfigOperation::SetProvider(pid) => {
let session = fc.session.get_or_insert_with(ModelConfig::default);
session.provider_id = Some(pid.as_ref().to_string());
}
ConfigOperation::SetModel(pid, mid) => {
let pid_str = pid.as_ref().to_string();
let mid_str = mid.to_string();
let session = fc.session.get_or_insert_with(ModelConfig::default);
if session.provider_id.as_deref() == Some(&pid_str) {
session.model_id = Some(mid_str);
} else {
fc.session =
Some(ModelConfig { provider_id: Some(pid_str), model_id: Some(mid_str) });
}
fc.session = Some(ModelConfig::new(&**pid, mid.as_str()));
}
ConfigOperation::SetCommitConfig(commit) => {
fc.commit = commit
.provider
.as_ref()
.zip(commit.model.as_ref())
.map(|(pid, mid)| ModelConfig {
provider_id: Some(pid.as_ref().to_string()),
model_id: Some(mid.to_string()),
});
fc.commit = Some(ModelConfig::new(&**commit.provider, commit.model.as_str()));
}
ConfigOperation::SetSuggestConfig(suggest) => {
fc.suggest = Some(ModelConfig {
provider_id: Some(suggest.provider.as_ref().to_string()),
model_id: Some(suggest.model.to_string()),
});
fc.suggest = Some(ModelConfig::new(
&**suggest.provider,
suggest.model.as_str(),
));
}
ConfigOperation::SetReasoningEffort(effort) => {
let config_effort = match effort {
Expand Down Expand Up @@ -228,33 +209,33 @@ mod tests {
}

#[test]
fn test_apply_config_op_set_provider() {
use forge_domain::ProviderId;
fn test_apply_config_op_set_model_creates_complete_session() {
use forge_domain::{ModelId, ProviderId};

let mut fixture = ForgeConfig::default();
apply_config_op(
&mut fixture,
ConfigOperation::SetProvider(ProviderId::ANTHROPIC),
ConfigOperation::SetModel(
ProviderId::ANTHROPIC,
ModelId::new("claude-3-5-sonnet-20241022"),
),
);

let actual = fixture
.session
.as_ref()
.and_then(|s| s.provider_id.as_deref());
let expected = Some("anthropic");
let actual_provider = fixture.session.as_ref().map(|s| s.provider_id.as_str());
let actual_model = fixture.session.as_ref().map(|s| s.model_id.as_str());
let expected_provider = Some("anthropic");
let expected_model = Some("claude-3-5-sonnet-20241022");

assert_eq!(actual, expected);
assert_eq!(actual_provider, expected_provider);
assert_eq!(actual_model, expected_model);
}

#[test]
fn test_apply_config_op_set_model_matching_provider() {
use forge_domain::{ModelId, ProviderId};

let mut fixture = ForgeConfig {
session: Some(ModelConfig {
provider_id: Some("anthropic".to_string()),
model_id: None,
}),
session: Some(ModelConfig::new("anthropic", "old-model")),
..Default::default()
};

Expand All @@ -266,7 +247,7 @@ mod tests {
),
);

let actual = fixture.session.as_ref().and_then(|s| s.model_id.as_deref());
let actual = fixture.session.as_ref().map(|s| s.model_id.as_str());
let expected = Some("claude-3-5-sonnet-20241022");

assert_eq!(actual, expected);
Expand All @@ -277,10 +258,7 @@ mod tests {
use forge_domain::{ModelId, ProviderId};

let mut fixture = ForgeConfig {
session: Some(ModelConfig {
provider_id: Some("openai".to_string()),
model_id: Some("gpt-4".to_string()),
}),
session: Some(ModelConfig::new("openai", "gpt-4")),
..Default::default()
};

Expand All @@ -292,11 +270,8 @@ mod tests {
),
);

let actual_provider = fixture
.session
.as_ref()
.and_then(|s| s.provider_id.as_deref());
let actual_model = fixture.session.as_ref().and_then(|s| s.model_id.as_deref());
let actual_provider = fixture.session.as_ref().map(|s| s.provider_id.as_str());
let actual_model = fixture.session.as_ref().map(|s| s.model_id.as_str());

assert_eq!(actual_provider, Some("anthropic"));
assert_eq!(actual_model, Some("claude-3-5-sonnet-20241022"));
Expand Down
Loading
Loading