Skip to content

Commit

Permalink
refactor: clean gemini embedding (#781)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusschiesser authored Apr 30, 2024
1 parent 0d50b22 commit 130b799
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 23 deletions.
4 changes: 2 additions & 2 deletions examples/gemini.ts → examples/gemini/embedding.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { GEMINI_MODEL, GeminiEmbedding } from "llamaindex";
import { GEMINI_EMBEDDING_MODEL, GeminiEmbedding } from "llamaindex";

async function main() {
if (!process.env.GOOGLE_API_KEY) {
throw new Error("Please set the GOOGLE_API_KEY environment variable.");
}
const embedModel = new GeminiEmbedding({
model: GEMINI_MODEL.GEMINI_PRO,
model: GEMINI_EMBEDDING_MODEL.EMBEDDING_001,
});
const texts = ["hello", "world"];
const embeddings = await embedModel.getTextEmbeddingsBatch(texts);
Expand Down
24 changes: 9 additions & 15 deletions packages/core/src/embeddings/GeminiEmbedding.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
import {
GEMINI_MODEL,
GeminiSessionStore,
type GeminiConfig,
type GeminiSession,
} from "../llm/gemini.js";
import { GeminiSessionStore, type GeminiSession } from "../llm/gemini.js";
import { BaseEmbedding } from "./types.js";

export enum GEMINI_EMBEDDING_MODEL {
EMBEDDING_001 = "embedding-001",
TEXT_EMBEDDING_004 = "text-embedding-004",
}

/**
* GeminiEmbedding is an alias for Gemini that implements the BaseEmbedding interface.
*/
export class GeminiEmbedding extends BaseEmbedding {
model: GEMINI_MODEL;
temperature: number;
topP: number;
maxTokens?: number;
model: GEMINI_EMBEDDING_MODEL;
session: GeminiSession;

constructor(init?: GeminiConfig) {
constructor(init?: Partial<GeminiEmbedding>) {
super();
this.model = init?.model ?? GEMINI_MODEL.GEMINI_PRO;
this.temperature = init?.temperature ?? 0.1;
this.topP = init?.topP ?? 1;
this.maxTokens = init?.maxTokens ?? undefined;
this.model = init?.model ?? GEMINI_EMBEDDING_MODEL.EMBEDDING_001;
this.session = init?.session ?? GeminiSessionStore.get();
}

Expand Down
6 changes: 0 additions & 6 deletions packages/core/src/llm/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ type GeminiSessionOptions = {
export enum GEMINI_MODEL {
GEMINI_PRO = "gemini-pro",
GEMINI_PRO_VISION = "gemini-pro-vision",
EMBEDDING_001 = "embedding-001",
AQA = "aqa",
GEMINI_PRO_LATEST = "gemini-1.5-pro-latest",
}

Expand All @@ -44,16 +42,12 @@ export interface GeminiModelInfo {
export const GEMINI_MODEL_INFO_MAP: Record<GEMINI_MODEL, GeminiModelInfo> = {
[GEMINI_MODEL.GEMINI_PRO]: { contextWindow: 30720 },
[GEMINI_MODEL.GEMINI_PRO_VISION]: { contextWindow: 12288 },
[GEMINI_MODEL.EMBEDDING_001]: { contextWindow: 2048 },
[GEMINI_MODEL.AQA]: { contextWindow: 7168 },
[GEMINI_MODEL.GEMINI_PRO_LATEST]: { contextWindow: 10 ** 6 },
};

const SUPPORT_TOOL_CALL_MODELS: GEMINI_MODEL[] = [
GEMINI_MODEL.GEMINI_PRO,
GEMINI_MODEL.GEMINI_PRO_VISION,
GEMINI_MODEL.EMBEDDING_001,
GEMINI_MODEL.AQA,
];

const DEFAULT_GEMINI_PARAMS = {
Expand Down

0 comments on commit 130b799

Please sign in to comment.