Deepinfra Embedder #1856

Open · wants to merge 4 commits into main
5 changes: 3 additions & 2 deletions docs/components/embedders/config.mdx
@@ -43,8 +43,8 @@ Config is essential for:

Here's a comprehensive list of all parameters that can be used across different embedders:

- | Parameter | Description |
- |-----------|-------------|
+ | Parameter | Description | Model |
+ |-----------|-------------|-------|
| `model` | Embedding model to use |
| `api_key` | API key of the provider |
| `embedding_dims` | Dimensions of the embedding model |
@@ -54,6 +54,7 @@ Here's a comprehensive list of all parameters that can be used across different
| `azure_kwargs` | Key-Value arguments for the AzureOpenAI embedding model |
| `openai_base_url` | Base URL for OpenAI API | OpenAI |
| `vertex_credentials_json` | Path to the Google Cloud credentials JSON file for VertexAI |
+ | `encoding_format` | The encoding format of the embedding | DeepInfra |


## Supported Embedding Models
40 changes: 40 additions & 0 deletions docs/components/embedders/models/deepinfra.mdx
@@ -0,0 +1,40 @@
---
title: DeepInfra
---

To use DeepInfra embedding models, set the `DEEPINFRA_TOKEN` environment variable. You can obtain a DeepInfra API token from the [DeepInfra Platform](https://deepinfra.com/dash/api_keys).

### Usage

```python
import os
from mem0 import Memory

os.environ["DEEPINFRA_TOKEN"] = "your_api_key"

config = {
    "embedder": {
        "provider": "deepinfra",
        "config": {
            "model": "BAAI/bge-large-en-v1.5",
            "embedding_dims": 1024,
            "encoding_format": "float",
        }
    }
}

m = Memory.from_config(config)
m.add("I'm visiting Paris", user_id="john")
```

### Config

Here are the parameters available for configuring DeepInfra embedder:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `model` | The name of the embedding model to use | `"BAAI/bge-large-en-v1.5"` |
| `embedding_dims` | Dimensions of the embedding model | `1024` |
| `api_key` | The DeepInfra API key | `None` |
| `encoding_format` | The encoding format of the embedding | `"float"` |

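For illustration, here is a variant of the usage snippet above that passes the token through the `api_key` config field instead of the environment variable. This is only a sketch based on the parameters in the table; the placeholder key is hypothetical.

```python
from mem0 import Memory

# Sketch: supply the DeepInfra token via `api_key` rather than DEEPINFRA_TOKEN.
config = {
    "embedder": {
        "provider": "deepinfra",
        "config": {
            "api_key": "your_api_key",  # hypothetical placeholder
            "model": "BAAI/bge-large-en-v1.5",
            "embedding_dims": 1024,
            "encoding_format": "float",
        }
    }
}

m = Memory.from_config(config)
m.add("I'm visiting Paris", user_id="john")
```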
1 change: 1 addition & 0 deletions docs/components/embedders/overview.mdx
@@ -13,6 +13,7 @@ See the list of supported embedders below.
<Card title="Azure OpenAI" href="/components/embedders/models/azure_openai"></Card>
<Card title="Ollama" href="/components/embedders/models/ollama"></Card>
<Card title="Hugging Face" href="/components/embedders/models/huggingface"></Card>
+ <Card title="DeepInfra" href="/components/embedders/models/deepinfra"></Card>
<Card title="Gemini" href="/components/embedders/models/gemini"></Card>
<Card title="Vertex AI" href="/components/embedders/models/vertexai"></Card>
</CardGroup>
3 changes: 2 additions & 1 deletion docs/mint.json
@@ -126,7 +126,8 @@
"components/embedders/models/azure_openai",
"components/embedders/models/ollama",
"components/embedders/models/huggingface",
"components/embedders/models/gemini"
"components/embedders/models/gemini",
"components/embedders/models/deepinfra"
]
}
]
12 changes: 9 additions & 3 deletions mem0/configs/embeddings/base.py
@@ -1,10 +1,9 @@
from abc import ABC
- from typing import Dict, Optional, Union
- from mem0.configs.base import AzureConfig
+ from typing import Optional, Union, Dict, Literal

import httpx

+ from mem0.configs.base import AzureConfig


class BaseEmbedderConfig(ABC):
"""
@@ -16,6 +15,8 @@ def __init__(
model: Optional[str] = None,
api_key: Optional[str] = None,
embedding_dims: Optional[int] = None,
+ # OpenAI-compatible parameter (currently used only by the DeepInfra embedder)
+ encoding_format: Optional[Literal["float", "base64"]] = None,
# Ollama specific
ollama_base_url: Optional[str] = None,
# Openai specific
@@ -35,6 +36,8 @@ def __init__(
:type api_key: Optional[str], optional
:param embedding_dims: The number of dimensions in the embedding, defaults to None
:type embedding_dims: Optional[int], optional
+ :param encoding_format: The encoding format of the embedding ("float" or "base64"), defaults to None
+ :type encoding_format: Optional[Literal["float", "base64"]], optional; when None, "base64" is used
:param ollama_base_url: Base URL for the Ollama API, defaults to None
:type ollama_base_url: Optional[str], optional
:param model_kwargs: key-value arguments for the huggingface embedding model, defaults a dict inside init
@@ -52,6 +55,9 @@ def __init__(
self.openai_base_url = openai_base_url
self.embedding_dims = embedding_dims

+ # OpenAI-compatible parameter (currently used only by the DeepInfra embedder); defaults to "base64"
+ self.encoding_format = encoding_format if encoding_format else "base64"

# AzureOpenAI specific
self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None

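A small sketch of how the `encoding_format` default resolves through the constructor above (assuming all other constructor arguments keep their defaults):

```python
from mem0.configs.embeddings.base import BaseEmbedderConfig

# When encoding_format is omitted, BaseEmbedderConfig falls back to "base64".
cfg = BaseEmbedderConfig(model="BAAI/bge-large-en-v1.5")
assert cfg.encoding_format == "base64"

# An explicit value is kept as-is.
cfg_float = BaseEmbedderConfig(model="BAAI/bge-large-en-v1.5", encoding_format="float")
assert cfg_float.encoding_format == "float"
```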
2 changes: 1 addition & 1 deletion mem0/embeddings/configs.py
@@ -13,7 +13,7 @@ class EmbedderConfig(BaseModel):
@field_validator("config")
def validate_config(cls, v, values):
provider = values.data.get("provider")
- if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai"]:
+ if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "deepinfra"]:
return v
else:
raise ValueError(f"Unsupported embedding provider: {provider}")
38 changes: 38 additions & 0 deletions mem0/embeddings/deepinfra.py
@@ -0,0 +1,38 @@
import os
from typing import Optional

from openai import OpenAI

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class DeepInfraEmbedding(EmbeddingBase):
    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        if not self.config.model:
            self.config.model = "BAAI/bge-large-en-v1.5"
            self.config.embedding_dims = 1024
            self.config.encoding_format = "float"

        api_key = self.config.api_key or os.getenv("DEEPINFRA_TOKEN")
        base_url = "https://api.deepinfra.com/v1/openai"
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def embed(self, text):
        """
        Get the embedding for the given text using the DeepInfra OpenAI-compatible API.

        Args:
            text (str): The text to embed.

        Returns:
            list: The embedding vector.
        """
        text = text.replace("\n", " ")
        return (
            self.client.embeddings.create(
                input=[text], model=self.config.model, encoding_format=self.config.encoding_format
            )
            .data[0]
            .embedding
        )
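For reference, a minimal sketch of calling the new embedder directly rather than through `Memory`. It assumes `DEEPINFRA_TOKEN` is set as described in the docs above, and mirrors the defaults applied in `__init__`:

```python
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.deepinfra import DeepInfraEmbedding

# Same values as the defaults set in __init__ when no model is given.
config = BaseEmbedderConfig(
    model="BAAI/bge-large-en-v1.5",
    embedding_dims=1024,
    encoding_format="float",
)

embedder = DeepInfraEmbedding(config)
vector = embedder.embed("I'm visiting Paris")  # expected: a list of 1024 floats
```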
1 change: 1 addition & 0 deletions mem0/utils/factory.py
@@ -42,6 +42,7 @@ class EmbedderFactory:
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
"azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
"gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
"deepinfra": "mem0.embeddings.deepinfra.DeepInfraEmbedding",
}

@classmethod
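For context, the string values in this mapping are dotted import paths that the factory resolves at runtime. The snippet below is only a rough illustration of that idea, not the actual `EmbedderFactory.create` implementation (whose body is not shown in this diff):

```python
import importlib

provider_to_class = {
    "deepinfra": "mem0.embeddings.deepinfra.DeepInfraEmbedding",
}

def load_embedder_class(provider: str):
    """Resolve a dotted 'module.ClassName' path to the class object."""
    dotted_path = provider_to_class[provider]
    module_path, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# e.g. load_embedder_class("deepinfra") -> DeepInfraEmbedding
```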
39 changes: 39 additions & 0 deletions tests/embeddings/test_deepinfra.py
@@ -0,0 +1,39 @@

from unittest.mock import MagicMock, patch

import pytest

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.deepinfra import DeepInfraEmbedding  # Update import as needed


@pytest.fixture
def mock_openai():
    with patch("mem0.embeddings.deepinfra.OpenAI") as mock_openai_client:
        yield mock_openai_client


@pytest.fixture
def config():
    return BaseEmbedderConfig(
        api_key="dummy_api_key",
        model="BAAI/bge-large-en-v1.5",
        embedding_dims=1024,
        encoding_format="float",
    )


def test_embed_query(mock_openai, config):
    mock_embedding_response = MagicMock()
    mock_embedding_response.data[0].embedding = [0.1, 0.2, 0.3, 0.4]
    mock_openai.return_value.embeddings.create.return_value = mock_embedding_response

    embedder = DeepInfraEmbedding(config)

    text = "Hello, world!"
    embedding = embedder.embed(text)

    assert embedding == [0.1, 0.2, 0.3, 0.4]
    mock_openai.return_value.embeddings.create.assert_called_once_with(
        input=["Hello, world!"],
        model="BAAI/bge-large-en-v1.5",
        encoding_format="float",
    )
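Assuming `pytest` is available in the development environment, the new test can be run on its own with `pytest tests/embeddings/test_deepinfra.py`.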