Skip to content

Commit

Permalink
Merge pull request #51 from nomic-ai/nom-1635-vector-search-in-js
Browse files Browse the repository at this point in the history
Nom 1635 vector search in js
  • Loading branch information
bmschmidt authored Jun 14, 2024
2 parents 5ada315 + 45fad75 commit c1c5317
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 25 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/npm-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: Publish @next release to npm
on:
push:
branches:
- main

permissions:
contents: write
packages: write
deployments: write

jobs:
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-node@v2
with:
node-version: '20.x'
registry-url: 'https://registry.npmjs.org'
- run: npm ci
- name: Configure Git user
run: |
git config --global user.email "[email protected]"
git config --global user.name "GitHub Actions"
- run: npm version prerelease --preid=next
- name: Commit bumped version
run: |
# git add package.json package-lock.json
# git commit -m "Bump version to $(node -p "require('./package.json').version")"
git push
- run: npm publish --tag next
env:
NODE_AUTH_TOKEN: ${{ secrets.NODE_AUTH_TOKEN }}
4 changes: 4 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.10.0

- Add support for nearest-neighbor search by vector.

# 0.9.6

- Rename "AtlasProject" to "AtlasDataset" with backwards compatible alias.
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@nomic-ai/atlas",
"version": "0.9.5",
"version": "0.10.0",
"type": "module",
"files": [
"dist"
Expand Down
22 changes: 11 additions & 11 deletions src/api-raw-types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1693,7 +1693,7 @@ export interface components {
*/
model:
| components['schemas']['NomicTextEmbeddingModel']
| components['schemas']['NomicVisionEmbeddingModel'];
| components['schemas']['NomicImageEmbeddingModel'];
/**
* Tokens
* @description The total tokens used.
Expand Down Expand Up @@ -1742,9 +1742,9 @@ export interface components {
atlas_index_id: string;
/**
* Queries
* @description The bytes of a batch of embeddings to get neighbors for
* @description A set of embeddings to query. Where `n` is the number of vectors to search and `d` is the vector dimensionality, this can be either an `n`x`d` numpy array encoded to base64, OR a list of `n` lists with `d` numbers per list.
*/
queries: string;
queries: string | number[][];
/**
* K
* @description The number of neighbors to return
Expand Down Expand Up @@ -2332,6 +2332,14 @@ export interface components {
*/
atom_strategies: string[];
};
/**
* NomicImageEmbeddingModel
* @description An enumeration.
* @enum {unknown}
*/
NomicImageEmbeddingModel:
| 'nomic-embed-vision-v1'
| 'nomic-embed-vision-v1.5';
/**
* NomicProjectModel
* @description An enumeration.
Expand All @@ -2348,14 +2356,6 @@ export interface components {
| 'nomic-embed-text-v1'
| 'nomic-embed-text-v1.5'
| 'nomic-embed-code';
/**
* NomicVisionEmbeddingModel
* @description An enumeration.
* @enum {unknown}
*/
NomicVisionEmbeddingModel:
| 'nomic-embed-vision-v1'
| 'nomic-embed-vision-v1.5';
/** ObtainAccessTokenRequest */
ObtainAccessTokenRequest: {
/**
Expand Down
26 changes: 26 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { AtlasUser } from './user.js';
import { AtlasProjection } from './projection.js';
import { AtlasDataset as AtlasDataset } from './project.js';
import type { Table } from 'apache-arrow';
import type { components } from 'api-raw-types.js';

type IndexInitializationOptions = {
project_id?: Atlas.UUID;
Expand Down Expand Up @@ -113,4 +114,29 @@ export class AtlasIndex extends BaseAtlasClass {
)) as Table;
return tb;
}

/**
*
* @param param0 A keyed dictionary including `k` (the number of neighbors to return)
* and `queries` (a list of vectors to search for).
* @returns
*/
async nearest_neighbors_by_vector({
k = 10,
queries,
}: Omit<
components['schemas']['EmbeddingNeighborRequest'],
'atlas_index_id'
>): Promise<components['schemas']['EmbeddingNeighborResponse']> {
const { neighbors, distances } = (await this.apiCall(
`/v1/project/data/get/nearest_neighbors/by_embedding`,
'POST',
{
atlas_index_id: this.id,
k,
queries,
}
)) as components['schemas']['EmbeddingNeighborResponse'];
return { neighbors, distances };
}
}
24 changes: 13 additions & 11 deletions src/project.ts
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,19 @@ export class AtlasDataset extends BaseAtlasClass {
* @param ids A list of identifiers to fetch from the server.
*/

async fetch_ids(ids?: string[]): Promise<Record<string, any>[]> {
throw new Error('Not implemented');
async fetch_ids(
ids?: string[]
): Promise<Record<string, Record<string, any>>> {
if (ids === undefined) {
return {};
}
const response = await this.apiCall(
'/v1/project/data/get',
'POST',
{ project_id: this.id, datum_ids: ids },
null
);
return response as Record<string, Record<string, any>>;
}

async createIndex(
Expand Down Expand Up @@ -285,15 +296,6 @@ export class AtlasDataset extends BaseAtlasClass {
return new AtlasIndex(id, this.user, { project: this });
}

async delete_data(ids: string[]): Promise<void> {
// TODO: untested
// const info = await this.info
await this.user.apiCall('/v1/project/data/delete', 'POST', {
project_id: this.id,
datum_ids: ids,
});
}

validate_metadata(): void {
// validate metadata
}
Expand Down
35 changes: 35 additions & 0 deletions src/projection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { BaseAtlasClass } from './user.js';
import type { AtlasUser } from './user.js';
import { AtlasDataset } from './project.js';
import type { AtlasIndex } from './index.js';
import { components } from 'api-raw-types.js';

type UUID = string;

Expand Down Expand Up @@ -300,6 +301,40 @@ export class AtlasProjection extends BaseAtlasClass {
return `${protocol}://${this.user.apiLocation}/v1/project/${this.project_id}/index/projection/${this.id}/quadtree`;
}

/**
*
* @param param0 an object with keys k (number of numbers) and queries (list of vectors, where each one is the length of the embedding space).
* @returns A list of entries in sorted order, where each entry is a list of neighbors including distances in the `_distance` field.
*/
async nearest_neighbors_by_vector({
k = 10,
queries,
}: Omit<
components['schemas']['EmbeddingNeighborRequest'],
'atlas_index_id'
>): Promise<Record<string, any>> {
const index = await this.index();
const { neighbors, distances } = await index.nearest_neighbors_by_vector({
k,
queries,
});
const project = await this.project();
const datums = (await Promise.all(
neighbors.map((ids) => project.fetch_ids(ids).then((d) => d.datums))
)) as Record<string, any>[][];
const filled_out: Record<string, any>[][] = [];
for (let i = 0; i < neighbors.length; i++) {
filled_out[i] = [];
for (let j = 0; j < neighbors[i].length; j++) {
const d = { ...datums[i][j] };
d._distance = distances[i][j];
filled_out[i].push(d);
}
}

return filled_out;
}

async info() {
if (this._info !== undefined) {
return this._info;
Expand Down
1 change: 0 additions & 1 deletion src/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ export class AtlasUser {
Record<string, any> | string | Array<any> | Table | Uint8Array | null
> {
// make an API call

if (headers === null) {
const credentials = await this.credentials;
if (credentials === null) {
Expand Down
23 changes: 23 additions & 0 deletions tests/neighbors.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { test } from 'uvu';
import { AtlasProjection } from '../dist/projection.js';
import { AtlasUser } from '../dist/user.js';
import * as assert from 'uvu/assert';

test('Neighbors', async () => {
// get user
const user = new AtlasUser({ useEnvToken: true });
const projection = new AtlasProjection(
'0efb002a-09b3-47df-b43e-71780879b501',
user,
{ project_id: 'b7d7ff07-7272-4481-8618-c05bcf6feca5' }
);
const vec = [];
for (let i = 0; i < 768; i++) {
vec.push(Math.random());
}
const result = await projection.nearest_neighbors_by_vector({
queries: [vec],
k: 25,
});
assert.is(result[0].length, 25);
});
7 changes: 6 additions & 1 deletion tests/user.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import { AtlasOrganization } from '../dist/organization.js';

test('AtlasOrganization test', async () => {
const user = new AtlasUser({ useEnvToken: true });
const info = await user.info();

const info = await user.info().catch((err) => {
console.error(err);
throw err;
});

const organization = new AtlasOrganization(
info.organizations[0].organization_id,
user
Expand Down

0 comments on commit c1c5317

Please sign in to comment.