Skip to content

Commit

Permalink
[Improvement] Set a default app id if not provided in the app configuration (#1300)
Browse files Browse the repository at this point in the history
  • Loading branch information
deshraj authored Mar 2, 2024
1 parent 8d7e8b6 commit faacfeb
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 41 deletions.
40 changes: 23 additions & 17 deletions docs/get-started/quickstart.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -31,41 +31,47 @@ This section gives a quickstart example of using Mistral as the Open source LLM
We are using Mistral hosted at Hugging Face, so you will need a Hugging Face token to run this example. It's *free* and you can create one [here](https://huggingface.co/docs/hub/security-tokens).

<CodeGroup>
```python quickstart.py
```python huggingface_demo.py
import os
# replace this with your HF key
# Replace this with your HF token
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_xxxx"

from embedchain import App
app = App.from_config("mistral.yaml")

config = {
'llm': {
'provider': 'huggingface',
'config': {
'model': 'mistralai/Mistral-7B-Instruct-v0.2',
'top_p': 0.5
}
},
'embedder': {
'provider': 'huggingface',
'config': {
'model': 'sentence-transformers/all-mpnet-base-v2'
}
}
}
app = App.from_config(config=config)
app.add("https://www.forbes.com/profile/elon-musk")
app.add("https://en.wikipedia.org/wiki/Elon_Musk")
app.query("What is the net worth of Elon Musk today?")
# Answer: The net worth of Elon Musk today is $258.7 billion.
```
```yaml mistral.yaml
llm:
provider: huggingface
config:
model: 'mistralai/Mistral-7B-Instruct-v0.2'
top_p: 0.5
embedder:
provider: huggingface
config:
model: 'sentence-transformers/all-mpnet-base-v2'
```
</CodeGroup>

## Paid Models

In this section, we will use both LLM and embedding model from OpenAI.

```python quickstart.py
```python openai_demo.py
import os
# replace this with your OpenAI key
from embedchain import App

# Replace this with your OpenAI key
os.environ["OPENAI_API_KEY"] = "sk-xxxx"

from embedchain import App
app = App()
app.add("https://www.forbes.com/profile/elon-musk")
app.add("https://en.wikipedia.org/wiki/Elon_Musk")
Expand Down
17 changes: 6 additions & 11 deletions embedchain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,15 @@
import json
import logging
import os
import uuid
from typing import Any, Optional, Union

import requests
import yaml
from tqdm import tqdm

from embedchain.cache import (
Config,
ExactMatchEvaluation,
SearchDistanceEvaluation,
cache,
gptcache_data_manager,
gptcache_pre_function,
)
from embedchain.cache import (Config, ExactMatchEvaluation,
SearchDistanceEvaluation, cache,
gptcache_data_manager, gptcache_pre_function)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.core.db.database import get_session, init_db, setup_engine
Expand All @@ -26,7 +20,8 @@
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
Groundedness)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
Expand Down Expand Up @@ -106,7 +101,7 @@ def __init__(

self.config = config or AppConfig()
self.name = self.config.name
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
self.config.id = self.local_id = "default-app-id" if self.config.id is None else self.config.id

if id is not None:
# Init client first since user is trying to fetch the pipeline
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
version = "0.1.91"
version = "0.1.92"
description = "Simplest open source retrieval(RAG) framework"
authors = [
"Taranjeet Singh <[email protected]>",
Expand Down
24 changes: 12 additions & 12 deletions tests/vectordb/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def test_incorrect_config_throws_error(self):
def test_initialize(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)

# Create a Qdrant instance
db = QdrantDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)

self.assertEqual(db.collection_name, "embedchain-store-1526")
self.assertEqual(db.collection_name, "embedchain-store-1536")
self.assertEqual(db.client, qdrant_client_mock.return_value)
qdrant_client_mock.return_value.get_collections.assert_called_once()

Expand All @@ -47,7 +47,7 @@ def test_get(self, qdrant_client_mock):

# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)

# Create a Qdrant instance
Expand All @@ -67,7 +67,7 @@ def test_add(self, uuid_mock, qdrant_client_mock):

# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)

# Create a Qdrant instance
Expand All @@ -80,9 +80,9 @@ def test_add(self, uuid_mock, qdrant_client_mock):
ids = ["123", "456"]
db.add(documents, metadatas, ids)
qdrant_client_mock.return_value.upsert.assert_called_once_with(
collection_name="embedchain-store-1526",
collection_name="embedchain-store-1536",
points=Batch(
ids=["def", "ghi"],
ids=["abc", "def"],
payloads=[
{
"identifier": "123",
Expand All @@ -103,7 +103,7 @@ def test_add(self, uuid_mock, qdrant_client_mock):
def test_query(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)

# Create a Qdrant instance
Expand All @@ -115,7 +115,7 @@ def test_query(self, qdrant_client_mock):
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})

qdrant_client_mock.return_value.search.assert_called_once_with(
collection_name="embedchain-store-1526",
collection_name="embedchain-store-1536",
query_filter=models.Filter(
must=[
models.FieldCondition(
Expand All @@ -134,7 +134,7 @@ def test_query(self, qdrant_client_mock):
def test_count(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)

# Create a Qdrant instance
Expand All @@ -143,13 +143,13 @@ def test_count(self, qdrant_client_mock):
App(config=app_config, db=db, embedding_model=embedder)

db.count()
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1526")
qdrant_client_mock.return_value.get_collection.assert_called_once_with(collection_name="embedchain-store-1536")

@patch("embedchain.vectordb.qdrant.QdrantClient")
def test_reset(self, qdrant_client_mock):
# Set the embedder
embedder = BaseEmbedder()
embedder.set_vector_dimension(1526)
embedder.set_vector_dimension(1536)
embedder.set_embedding_fn(mock_embedding_fn)

# Create a Qdrant instance
Expand All @@ -159,7 +159,7 @@ def test_reset(self, qdrant_client_mock):

db.reset()
qdrant_client_mock.return_value.delete_collection.assert_called_once_with(
collection_name="embedchain-store-1526"
collection_name="embedchain-store-1536"
)


Expand Down

0 comments on commit faacfeb

Please sign in to comment.