Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(LLM/providers): Adding Ollama as LLM Provider #968

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions apps/api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,8 @@ RESEND_API_KEY=
# TRACE - For logging more detailed information than the DEBUG level.
# Set LOGGING_LEVEL to one of the above options to control logging output.
LOGGING_LEVEL=INFO

# LLM provider selection: 'openai' (default) or 'ollama'
LLM_PROVIDER=openai
# Only needed when LLM_PROVIDER=ollama (defaults to http://localhost:11434)
OLLAMA_URL=
MODEL_NAME=
OLLAMA_EMBEDDING_MODEL=nomic-embed-text
6 changes: 3 additions & 3 deletions apps/api/src/controllers/v1/extract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { waitForJob } from "../../services/queue-jobs";
import { addScrapeJob } from "../../services/queue-jobs";
import { PlanType } from "../../types";
import { getJobPriority } from "../../lib/job-priority";
import { generateOpenAICompletions } from "../../scraper/scrapeURL/transformers/llmExtract";
import { generateLLMCompletions } from "../../scraper/scrapeURL/transformers/llmExtract";
import { isUrlBlocked } from "../../scraper/WebScraper/utils/blocklist";
import { getMapResults } from "./map";
import { buildDocument } from "../../lib/extract/build-document";
Expand Down Expand Up @@ -232,8 +232,8 @@ export async function extractController(
});
}

const completions = await generateOpenAICompletions(
logger.child({ method: "extractController/generateOpenAICompletions" }),
const completions = await generateLLMCompletions(
logger.child({ method: "extractController/generateLLMCompletions" }),
{
mode: "llm",
systemPrompt:
Expand Down
193 changes: 120 additions & 73 deletions apps/api/src/lib/ranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,91 +4,138 @@ import OpenAI from "openai";

configDotenv();

const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

async function getEmbedding(text: string) {
const embedding = await openai.embeddings.create({
model: "text-embedding-ada-002",
input: text,
encoding_format: "float",
});

return embedding.data[0].embedding;
type LLMProvider = 'openai' | 'ollama';

class EmbeddingError extends Error {
constructor(message: string) {
super(`Embedding Error: ${message}`);
}
}

async function getOpenAIEmbedding(text: string): Promise<number[]> {
if (!process.env.OPENAI_API_KEY) {
throw new EmbeddingError('OpenAI API key is not configured');
}

const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

const embedding = await openai.embeddings.create({
model: "text-embedding-ada-002",
input: text,
encoding_format: "float",
});

return embedding.data[0].embedding;
}

/**
 * Fetch an embedding vector from an Ollama server.
 *
 * Reads OLLAMA_URL (default http://localhost:11434) and
 * OLLAMA_EMBEDDING_MODEL. The fallback model is 'nomic-embed-text' to match
 * the documented .env.example default — the previous fallback
 * 'llama2:latest' is a chat model, not an embedding model.
 *
 * @throws EmbeddingError when the HTTP call fails or the payload is malformed.
 */
async function getOllamaEmbedding(text: string): Promise<number[]> {
  const ollamaUrl = process.env.OLLAMA_URL || 'http://localhost:11434';
  const model = process.env.OLLAMA_EMBEDDING_MODEL || 'nomic-embed-text';

  const response = await fetch(`${ollamaUrl}/api/embeddings`, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      model,
      prompt: text,
    }),
  });

  if (!response.ok) {
    throw new EmbeddingError(`Ollama API returned ${response.status}`);
  }

  // Validate the payload instead of trusting `data.embedding` blindly.
  const data = await response.json();
  if (!data || !Array.isArray(data.embedding)) {
    throw new EmbeddingError('Ollama API response is missing an embedding array');
  }
  return data.embedding;
}

/**
 * Embed `text` with whichever backend LLM_PROVIDER selects
 * ('openai' by default, or 'ollama').
 * @throws EmbeddingError for an unrecognized provider value.
 */
async function getEmbedding(text: string): Promise<number[]> {
  const provider = (process.env.LLM_PROVIDER || 'openai') as LLMProvider;

  if (provider === 'openai') return getOpenAIEmbedding(text);
  if (provider === 'ollama') return getOllamaEmbedding(text);

  throw new EmbeddingError(`Unsupported LLM provider: ${provider}`);
}

/**
 * Cosine similarity of two equal-length vectors, in [-1, 1].
 * Returns 0 when either vector has zero magnitude (avoids divide-by-zero).
 * (The previous text contained two interleaved copies of this body from a
 * mangled diff — `dotProduct`/`magnitude1`/`magnitude2` were each declared
 * twice, which does not compile.)
 */
const cosineSimilarity = (vec1: number[], vec2: number[]): number => {
  const dotProduct = vec1.reduce((sum, val, i) => sum + val * vec2[i], 0);
  const magnitude1 = Math.sqrt(vec1.reduce((sum, val) => sum + val * val, 0));
  const magnitude2 = Math.sqrt(vec2.reduce((sum, val) => sum + val * val, 0));
  if (magnitude1 === 0 || magnitude2 === 0) return 0;
  return dotProduct / (magnitude1 * magnitude2);
};

// Crude lexical relevance vector: for each query word, the number of
// occurrences of that word in `text`, normalized by text length. Words come
// from a \W+ split, so they contain no regex metacharacters.
// NOTE(review): an empty `text` yields Infinity/NaN entries — confirm callers
// guard against empty inputs. (The previous text contained two interleaved
// copies of this body from a mangled diff.)
const textToVector = (searchQuery: string, text: string): number[] => {
  const haystack = text.toLowerCase(); // hoisted out of the loop
  const words = searchQuery.toLowerCase().split(/\W+/);
  return words.map((word) => {
    const count = (haystack.match(new RegExp(word, "g")) || []).length;
    return count / text.length;
  });
};

async function performRanking(
linksWithContext: string[],
links: string[],
searchQuery: string,
) {
try {
// Handle invalid inputs
if (!searchQuery || !linksWithContext.length || !links.length) {
return [];
}

// Sanitize search query by removing null characters
const sanitizedQuery = searchQuery;

// Generate embeddings for the search query
const queryEmbedding = await getEmbedding(sanitizedQuery);

// Generate embeddings for each link and calculate similarity
const linksAndScores = await Promise.all(
linksWithContext.map(async (linkWithContext, index) => {
try {
const linkEmbedding = await getEmbedding(linkWithContext);
const score = cosineSimilarity(queryEmbedding, linkEmbedding);

return {
link: links[index],
linkWithContext,
score,
originalIndex: index,
};
} catch (err) {
// If embedding fails for a link, return with score 0
return {
link: links[index],
linkWithContext,
score: 0,
originalIndex: index,
};
async function performRanking(linksWithContext: string[], links: string[], searchQuery: string) {
try {
// Handle invalid inputs
if (!searchQuery || !linksWithContext.length || !links.length) {
return [];
}
}),
);

// Sort links based on similarity scores while preserving original order for equal scores
linksAndScores.sort((a, b) => {
const scoreDiff = b.score - a.score;
return scoreDiff === 0 ? a.originalIndex - b.originalIndex : scoreDiff;
});
// Sanitize search query by removing null characters
const sanitizedQuery = searchQuery;

return linksAndScores;
} catch (error) {
console.error(`Error performing semantic search: ${error}`);
return [];
}
// Generate embeddings for the search query
const queryEmbedding = await getEmbedding(sanitizedQuery);

// Generate embeddings for each link and calculate similarity
const linksAndScores = await Promise.all(linksWithContext.map(async (linkWithContext, index) => {
try {
const linkEmbedding = await getEmbedding(linkWithContext);
const score = cosineSimilarity(queryEmbedding, linkEmbedding);

return {
link: links[index],
linkWithContext,
score,
originalIndex: index
};
} catch (err) {
console.error(`Error generating embedding for link: ${err}`);
// If embedding fails for a link, return with score 0
return {
link: links[index],
linkWithContext,
score: 0,
originalIndex: index
};
}
}));

// Sort links based on similarity scores while preserving original order for equal scores
linksAndScores.sort((a, b) => {
const scoreDiff = b.score - a.score;
return scoreDiff === 0 ? a.originalIndex - b.originalIndex : scoreDiff;
});

return linksAndScores;
} catch (error) {
console.error(`Error performing semantic search: ${error}`);
return [];
}
}

export { performRanking };
Loading