-
Notifications
You must be signed in to change notification settings - Fork 373
/
JinaAIEmbedding.ts
155 lines (140 loc) · 4.67 KB
/
JinaAIEmbedding.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import { MultiModalEmbedding } from "@llamaindex/core/embeddings";
import { getEnv } from "@llamaindex/env";
import { imageToDataUrl } from "../internal/utils.js";
import type { ImageType } from "../Node.js";
/**
 * Decide whether an image must be read locally and inlined as bytes, rather
 * than handed to the API by URL.
 *
 * - A `Blob` is always local.
 * - A `URL`, or a string that parses as an absolute URL, is local only when
 *   its protocol is `file:`.
 * - A string that cannot be parsed as an absolute URL (e.g. `"./cat.png"`,
 *   `"/tmp/cat.png"`) is treated as a filesystem path, and therefore local,
 *   instead of letting `new URL` throw `TypeError: Invalid URL`.
 *   NOTE(review): this relies on `imageToDataUrl` accepting a path string — confirm.
 *
 * @param url - The image reference to classify.
 * @returns `true` if the image should be loaded locally.
 */
function isLocal(url: ImageType): boolean {
  if (url instanceof Blob) return true;
  let parsed: URL;
  try {
    parsed = new URL(url);
  } catch {
    // Not an absolute URL, so the only sensible reading is a filesystem path.
    return true;
  }
  return parsed.protocol === "file:";
}
/** Allowed values for the `task` field of a {@link JinaEmbeddingRequest}. */
type TaskType =
  | "classification"
  | "retrieval.passage"
  | "retrieval.query"
  | "separation"
  | "text-matching";

/** Allowed values for the `encoding_type` field of a {@link JinaEmbeddingRequest}. */
type EncodingType = "ubinary" | "binary" | "float";

/** JSON body POSTed to the Jina AI embeddings endpoint. */
export type JinaEmbeddingRequest = {
  /** Items to embed: plain text, an image URL, or encoded image bytes. */
  input: ({ text: string } | { url: string } | { bytes: string })[];
  model?: string;
  encoding_type?: EncodingType;
  task?: TaskType;
  dimensions?: number;
  late_chunking?: boolean;
};

/** JSON body returned by the Jina AI embeddings endpoint. */
export type JinaEmbeddingResponse = {
  model: string;
  object: string;
  usage: { total_tokens: number; prompt_tokens: number };
  data: {
    object: string;
    /** Ordering key; the client sorts returned items by it before use. */
    index: number;
    embedding: number[];
  }[];
};
/** Models allowed to receive image inputs (`url` / `bytes`); requests with images for any other model are rejected client-side. */
const JINA_MULTIMODAL_MODELS: string[] = ["jina-clip-v1"];
/**
 * Embedding client for the Jina AI embeddings HTTP API.
 *
 * Text inputs are sent as `{ text }`. Image inputs are sent either as
 * `{ url }` (remote) or `{ bytes }` (local, inlined), and are only allowed
 * when `model` is listed in `JINA_MULTIMODAL_MODELS`. Optional settings
 * (`task`, `dimensions`, `late_chunking`) are added to the request body only
 * when set; `encoding_type` falls back to `"float"`.
 */
export class JinaAIEmbedding extends MultiModalEmbedding {
  /** Sent as a Bearer token; falls back to the `JINAAI_API_KEY` env variable. */
  apiKey: string;
  /** Model name sent with every request; defaults to `"jina-embeddings-v3"`. */
  model: string;
  /** Endpoint URL that requests are POSTed to. */
  baseURL: string;
  /** Optional `task` request field. */
  task?: TaskType | undefined;
  /** `encoding_type` request field; `"float"` is sent when unset. */
  encodingType?: EncodingType | undefined;
  /** Optional `dimensions` request field. */
  dimensions?: number | undefined;
  /** Optional `late_chunking` request field. */
  late_chunking?: boolean | undefined;

  /**
   * @param init - Optional overrides for any public field.
   * @throws Error when no API key is passed and `JINAAI_API_KEY` is unset.
   */
  constructor(init?: Partial<JinaAIEmbedding>) {
    super();
    const apiKey = init?.apiKey ?? getEnv("JINAAI_API_KEY");
    if (!apiKey) {
      throw new Error(
        "Set Jina AI API Key in JINAAI_API_KEY env variable. Get one for free or top up your key at https://jina.ai/embeddings",
      );
    }
    this.apiKey = apiKey;
    this.model = init?.model ?? "jina-embeddings-v3";
    this.baseURL = init?.baseURL ?? "https://api.jina.ai/v1/embeddings";
    if (init?.embedBatchSize) {
      this.embedBatchSize = init.embedBatchSize;
    }
    this.task = init?.task;
    this.encodingType = init?.encodingType;
    this.dimensions = init?.dimensions;
    this.late_chunking = init?.late_chunking;
  }

  /**
   * Embed a single text string.
   * @throws Error if the request fails or the response contains no embedding.
   */
  async getTextEmbedding(text: string): Promise<number[]> {
    const result = await this.getJinaEmbedding({ input: [{ text }] });
    return this.firstEmbedding(result);
  }

  /**
   * Embed a single image (requires a multimodal model).
   * @throws Error if the model is not multimodal, the request fails, or the
   *   response contains no embedding.
   */
  async getImageEmbedding(image: ImageType): Promise<number[]> {
    const input = await this.getImageInput(image);
    const result = await this.getJinaEmbedding({ input: [input] });
    return this.firstEmbedding(result);
  }

  /**
   * Embed several texts in a single request; results follow input order.
   * Declared as an arrow-function property (not a method) so it overrides the
   * base class member of the same name.
   */
  getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => {
    // Nothing to embed: skip the network round-trip.
    if (texts.length === 0) return [];
    const input = texts.map((text) => ({ text }));
    const result = await this.getJinaEmbedding({ input });
    return result.data.map((d) => d.embedding);
  };

  /** Embed several images in a single request; results follow input order. */
  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
    // Nothing to embed: skip the network round-trip.
    if (images.length === 0) return [];
    const input = await Promise.all(
      images.map((img) => this.getImageInput(img)),
    );
    const result = await this.getJinaEmbedding({ input });
    return result.data.map((d) => d.embedding);
  }

  /** Return the first embedding of a response, failing loudly if there is none. */
  private firstEmbedding(result: JinaEmbeddingResponse): number[] {
    const first = result.data[0];
    if (!first) {
      throw new Error("Jina AI API returned no embeddings");
    }
    return first.embedding;
  }

  /**
   * Convert an image reference into a request input item.
   * Local images (see `isLocal`) and Blobs are inlined as `{ bytes }`; anything
   * else is sent by reference as `{ url }`.
   */
  private async getImageInput(
    image: ImageType,
  ): Promise<{ bytes: string } | { url: string }> {
    if (isLocal(image) || image instanceof Blob) {
      // `imageToDataUrl` yields a data URL; only the payload after its first
      // comma is sent as `bytes`.
      const dataUrl = await imageToDataUrl(image);
      const separator = dataUrl.indexOf(",");
      if (separator === -1) {
        throw new Error("Could not extract image bytes: expected a data URL");
      }
      return { bytes: dataUrl.slice(separator + 1) };
    }
    return { url: image.toString() };
  }

  /**
   * POST one embeddings request and return the response with `data` ordered by
   * each item's `index`.
   * @throws Error if image inputs are used with a non-multimodal model, the
   *   HTTP status is not OK, or the response has no `data` array.
   */
  private async getJinaEmbedding(
    params: JinaEmbeddingRequest,
  ): Promise<JinaEmbeddingResponse> {
    // Image inputs ("url"/"bytes") are only accepted by multimodal models.
    const hasImageInput = params.input.some((i) => "url" in i || "bytes" in i);
    if (hasImageInput && !JINA_MULTIMODAL_MODELS.includes(this.model)) {
      throw new Error(
        `Model ${this.model} does not support image embeddings. Use ${JINA_MULTIMODAL_MODELS.join(", ")}`,
      );
    }
    const response = await fetch(this.baseURL, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${this.apiKey}`,
      },
      body: JSON.stringify({
        model: this.model,
        encoding_type: this.encodingType ?? "float",
        // Optional fields are only included when configured.
        ...(this.task && { task: this.task }),
        ...(this.dimensions !== undefined && { dimensions: this.dimensions }),
        ...(this.late_chunking !== undefined && {
          late_chunking: this.late_chunking,
        }),
        ...params,
      }),
    });
    if (!response.ok) {
      const reason = await response.text();
      throw new Error(
        `Request failed with status ${response.status}: ${reason}`,
      );
    }
    const result: JinaEmbeddingResponse = await response.json();
    if (!Array.isArray(result?.data)) {
      throw new Error(
        "Unexpected response from Jina AI API: missing `data` array",
      );
    }
    // Order by `index` so callers can map results positionally; sort a copy
    // rather than mutating the parsed response.
    return {
      ...result,
      data: [...result.data].sort((a, b) => a.index - b.index),
    };
  }
}