Commit

Merge pull request #76 from meysamhadeli/feat/support-diffrent-keys-for-chat-and-embeddings

Feat/support different keys for chat and embeddings
meysamhadeli authored Nov 16, 2024
2 parents 2f47379 + 608e30b commit e455797
Showing 4 changed files with 28 additions and 17 deletions.
14 changes: 8 additions & 6 deletions README.md
@@ -16,7 +16,7 @@ and suggest improvements or new code based on your context. This AI-powered tool
1. **RAG** (Retrieval-Augmented Generation)

-2. **Summarize Full Context of Code**.
+2. **Summarize Full Context of Code with Tree-sitter**.

Each method has its own benefits and is chosen depending on the specific needs of the request. Below is a description of each method.

@@ -27,8 +27,8 @@ codai **retrieves just the necessary context**, which is then sent with the user
helpful responses, making it the recommended method.

## 🧩 Summarize Full Context of Code
-Another approach involves creating a **summary of the full context of project** and sending it to the AI. When a **user requests a specific part of code**,
-the system can **retrieve the full context for just that section**. This method also **saves tokens** because it **sends only relevant parts**, but
+Another approach involves creating a **summary of the full context of the project** with **Tree-sitter**. In this approach, we send only the **signatures of our code**, without the **full implementation of each code block**, to the AI. When a **user requests a specific part of the code**,
+the system can **retrieve the full context for just that section**. This approach also **saves tokens** because it **sends only the relevant parts**, but
it usually uses **slightly more tokens than the RAG method**. In **RAG**, only the **related context is sent to the AI**, **saving even more tokens**.


@@ -44,12 +44,14 @@ To use codai, you need to set your environment variables for the API keys.

For `Bash`, use:
```bash
-export API_KEY="your_api_key"
+export CHAT_API_KEY="your_chat_api_key"
+export EMBEDDINGS_API_KEY="your_embeddings_api_key" # (Optional, if you want to use RAG.)
```

For `PowerShell`, use:
```powershell
-$env:API_KEY="your_api_key"
+$env:CHAT_API_KEY="your_chat_api_key"
+$env:EMBEDDINGS_API_KEY="your_embeddings_api_key" # (Optional, if you want to use RAG.)
```
### 🔧 Configuration
`codai` requires a `config.yml` file in the `root of your working directory`, or you can use `environment variables` to set the configs below `globally`.
@@ -59,7 +61,7 @@ The `config` file should look like the following example, based on your `AI provider`:
**config.yml**
```yml
ai_provider_config:
-  provider_name: "openai" # openai | ollama
+  provider_name: "openai" # openai | ollama | azure-openai
  chat_completion_url: "https://api.openai.com/v1/chat/completions"
  chat_completion_model: "gpt-4o"
  embedding_url: "https://api.openai.com/v1/embeddings" # (Optional, if you want to use RAG.)
15 changes: 10 additions & 5 deletions config/config.go
@@ -32,7 +32,8 @@ var defaultConfig = Config{
EncodingFormat: "float",
Temperature: 0.2,
Threshold: 0,
-ApiKey: "",
+ChatApiKey: "",
+EmbeddingsApiKey: "",
},
}

@@ -55,7 +56,8 @@ func LoadConfigs(rootCmd *cobra.Command, cwd string) *Config {
viper.SetDefault("ai_provider_config.encoding_format", defaultConfig.AIProviderConfig.EncodingFormat)
viper.SetDefault("ai_provider_config.temperature", defaultConfig.AIProviderConfig.Temperature)
viper.SetDefault("ai_provider_config.threshold", defaultConfig.AIProviderConfig.Threshold)
-viper.SetDefault("ai_provider_config.api_key", defaultConfig.AIProviderConfig.ApiKey)
+viper.SetDefault("ai_provider_config.chat_api_key", defaultConfig.AIProviderConfig.ChatApiKey)
+viper.SetDefault("ai_provider_config.embeddings_api_key", defaultConfig.AIProviderConfig.EmbeddingsApiKey)

// Automatically read environment variables
viper.AutomaticEnv() // This will look for variables that match config keys directly
@@ -109,7 +111,8 @@ func bindEnv() {
_ = viper.BindEnv("ai_provider_config.embedding_model", "EMBEDDING_MODEL")
_ = viper.BindEnv("ai_provider_config.temperature", "TEMPERATURE")
_ = viper.BindEnv("ai_provider_config.threshold", "THRESHOLD")
-_ = viper.BindEnv("ai_provider_config.api_key", "API_KEY")
+_ = viper.BindEnv("ai_provider_config.chat_api_key", "CHAT_API_KEY")
+_ = viper.BindEnv("ai_provider_config.embeddings_api_key", "EMBEDDINGS_API_KEY")
}

// bindFlags binds the CLI flags to configuration values.
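With defaults, env bindings, and flag bindings all registered, one value can arrive from several places. Viper's documented precedence is: explicit CLI flag, then environment variable, then config file, then default. A minimal stdlib sketch of that lookup order for a single key (the function name `resolve` is illustrative, not from the codebase):

```go
package main

import "fmt"

// resolve mimics viper's precedence for one config key:
// flag > environment variable > config file > default.
func resolve(flagVal, envVal, fileVal, defVal string) string {
	for _, v := range []string{flagVal, envVal, fileVal} {
		if v != "" {
			return v
		}
	}
	return defVal
}

func main() {
	// The config file sets a chat key, but the environment overrides it.
	fmt.Println(resolve("", "env-chat-key", "file-chat-key", ""))
}
```

This is why the patch registers `chat_api_key` and `embeddings_api_key` in all three places (SetDefault, BindEnv, BindPFlag): each layer must know about both keys for the override chain to work.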
@@ -123,7 +126,8 @@ func bindFlags(rootCmd *cobra.Command) {
_ = viper.BindPFlag("ai_provider_config.embedding_model", rootCmd.Flags().Lookup("embedding_model"))
_ = viper.BindPFlag("ai_provider_config.temperature", rootCmd.Flags().Lookup("temperature"))
_ = viper.BindPFlag("ai_provider_config.threshold", rootCmd.Flags().Lookup("threshold"))
-_ = viper.BindPFlag("ai_provider_config.api_key", rootCmd.Flags().Lookup("api_key"))
+_ = viper.BindPFlag("ai_provider_config.chat_api_key", rootCmd.Flags().Lookup("chat_api_key"))
+_ = viper.BindPFlag("ai_provider_config.embeddings_api_key", rootCmd.Flags().Lookup("embeddings_api_key"))
}

// InitFlags initializes the flags for the root command.
@@ -140,5 +144,6 @@ func InitFlags(rootCmd *cobra.Command) {
rootCmd.PersistentFlags().String("embedding_model", defaultConfig.AIProviderConfig.EmbeddingModel, "Specifies the AI model used for generating text embeddings (e.g., 'text-embedding-ada-002'). This model converts text into vector representations for similarity comparisons.")
rootCmd.PersistentFlags().Float32("temperature", defaultConfig.AIProviderConfig.Temperature, "Adjusts the AI model’s creativity by setting a temperature value. Higher values result in more creative or varied responses, while lower values make them more focused (e.g., value should be between '0 - 1' and default is '0.2').")
rootCmd.PersistentFlags().Float64("threshold", defaultConfig.AIProviderConfig.Threshold, "Sets the threshold for similarity calculations in AI systems. Higher values will require closer matches and should be careful not to lose matches, while lower values provide a wider range of results to prevent losing any matches. (e.g., value should be between '0.2 - 1' and default is '0.3').")
-rootCmd.PersistentFlags().String("api_key", defaultConfig.AIProviderConfig.ApiKey, "The API key used to authenticate with the AI service provider.")
+rootCmd.PersistentFlags().String("chat_api_key", defaultConfig.AIProviderConfig.ChatApiKey, "The chat API key used to authenticate with the AI service provider.")
+rootCmd.PersistentFlags().String("embeddings_api_key", defaultConfig.AIProviderConfig.EmbeddingsApiKey, "The embeddings API key used to authenticate with the AI service provider.")
}
6 changes: 4 additions & 2 deletions providers/ai_provider.go
@@ -18,7 +18,8 @@ type AIProviderConfig struct {
EncodingFormat string `mapstructure:"encoding_format"`
MaxTokens int `mapstructure:"max_tokens"`
Threshold float64 `mapstructure:"threshold"`
-ApiKey string `mapstructure:"api_key"`
+ChatApiKey string `mapstructure:"chat_api_key"`
+EmbeddingsApiKey string `mapstructure:"embeddings_api_key"`
}

// ProviderFactory creates a Provider based on the given provider config.
@@ -45,7 +46,8 @@ func ProviderFactory(config *AIProviderConfig, tokenManagement contracts.ITokenM
EmbeddingModel: config.EmbeddingModel,
ChatCompletionURL: config.ChatCompletionURL,
EmbeddingURL: config.EmbeddingURL,
-ApiKey: config.ApiKey,
+ChatApiKey: config.ChatApiKey,
+EmbeddingsApiKey: config.EmbeddingsApiKey,
MaxTokens: config.MaxTokens,
Threshold: config.Threshold,
TokenManagement: tokenManagement,
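The factory now threads both keys through to the concrete provider instead of a single shared one. A compressed, self-contained sketch of the shape of this wiring (the helper `pickProvider` is hypothetical; the real `ProviderFactory` also wires URLs, models, and token management):

```go
package main

import (
	"errors"
	"fmt"
)

// AIProviderConfig mirrors the shape of the patched config: one key
// for chat completions and a separate, optional one for embeddings.
type AIProviderConfig struct {
	ProviderName     string
	ChatApiKey       string
	EmbeddingsApiKey string
}

// pickProvider validates the provider name and would, in the real
// factory, construct the matching provider with both keys attached.
func pickProvider(cfg AIProviderConfig) (string, error) {
	switch cfg.ProviderName {
	case "openai", "ollama", "azure-openai":
		return cfg.ProviderName, nil
	default:
		return "", errors.New("unsupported provider: " + cfg.ProviderName)
	}
}

func main() {
	name, err := pickProvider(AIProviderConfig{
		ProviderName:     "azure-openai",
		ChatApiKey:       "chat-key",
		EmbeddingsApiKey: "emb-key",
	})
	fmt.Println(name, err)
}
```

Because the struct carries both fields, every downstream provider can choose the right key per request without any global state.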
10 changes: 6 additions & 4 deletions providers/openai/openai_provider.go
@@ -25,7 +25,8 @@ type OpenAIConfig struct {
ChatCompletionModel string
Temperature float32
EncodingFormat string
-ApiKey string
+ChatApiKey string
+EmbeddingsApiKey string
MaxTokens int
Threshold float64
TokenManagement contracts.ITokenManagement
@@ -42,7 +43,8 @@ func NewOpenAIProvider(config *OpenAIConfig) contracts.IAIProvider {
EncodingFormat: config.EncodingFormat,
MaxTokens: config.MaxTokens,
Threshold: config.Threshold,
-ApiKey: config.ApiKey,
+ChatApiKey: config.ChatApiKey,
+EmbeddingsApiKey: config.EmbeddingsApiKey,
TokenManagement: config.TokenManagement,
Name: config.Name,
}
@@ -71,7 +73,7 @@ func (openAIProvider *OpenAIConfig) EmbeddingRequest(ctx context.Context, prompt

// Set required headers
req.Header.Set("Content-Type", "application/json")
-req.Header.Set("api-key", openAIProvider.ApiKey)
+req.Header.Set("api-key", openAIProvider.EmbeddingsApiKey)

// Make the HTTP request
client := &http.Client{}
@@ -154,7 +156,7 @@ func (openAIProvider *OpenAIConfig) ChatCompletionRequest(ctx context.Context, u
}

req.Header.Set("Content-Type", "application/json")
-req.Header.Set("api-key", openAIProvider.ApiKey)
+req.Header.Set("api-key", openAIProvider.ChatApiKey)

client := &http.Client{}
resp, err := client.Do(req)
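The net effect of the provider change: each request type authenticates with the key that matches it. A self-contained sketch of that selection logic (the `newRequest` helper is illustrative, not the provider's exact code):

```go
package main

import (
	"fmt"
	"net/http"
)

// Config carries the two keys, as in the patched OpenAIConfig.
type Config struct {
	ChatApiKey       string
	EmbeddingsApiKey string
}

// newRequest builds an HTTP request and attaches the key matching the
// endpoint: embeddings requests get the embeddings key, everything
// else gets the chat key.
func (c *Config) newRequest(url string, isEmbedding bool) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodPost, url, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	if isEmbedding {
		req.Header.Set("api-key", c.EmbeddingsApiKey)
	} else {
		req.Header.Set("api-key", c.ChatApiKey)
	}
	return req, nil
}

func main() {
	c := &Config{ChatApiKey: "chat-key", EmbeddingsApiKey: "emb-key"}
	req, _ := c.newRequest("https://api.openai.com/v1/embeddings", true)
	fmt.Println(req.Header.Get("api-key")) // → emb-key
}
```

This mirrors what the two one-line header changes above accomplish: `EmbeddingRequest` reads `EmbeddingsApiKey`, while `ChatCompletionRequest` reads `ChatApiKey`.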
