diff --git a/examples/estore/index.ts b/examples/estore/index.ts index dee8458..8ef01f7 100644 --- a/examples/estore/index.ts +++ b/examples/estore/index.ts @@ -155,7 +155,7 @@ async function main() { console.log("Executing database methods..."); console.log('Creating "users" table...'); - const newUsersTable = await db.createTable({ + const newUsersTable = await db.createTable<"users", Database["tables"]["users"]>({ name: "users", columns: { id: { type: "bigint", autoIncrement: true, primaryKey: true }, diff --git a/packages/client/src/database.ts b/packages/client/src/database.ts index e5dcaa5..9df2814 100644 --- a/packages/client/src/database.ts +++ b/packages/client/src/database.ts @@ -23,7 +23,7 @@ export interface DatabaseType { */ export interface DatabaseSchema { name: string; - tables?: { [K in keyof TType["tables"]]: Omit, "name"> }; + tables?: { [K in keyof TType["tables"]]: Omit, "name"> }; } /** @@ -255,7 +255,7 @@ export class Database} A `Table` instance representing the specified table. + * @returns {Table} A `Table` instance representing the specified table. */ table< TType, @@ -263,6 +263,7 @@ export class Database( name: TName, ): Table< + TName, TType extends TableType ? TType : TDatabaseType["tables"][TName] extends TableType @@ -272,6 +273,7 @@ export class Database { return new Table< + TName, TType extends TableType ? TType : TDatabaseType["tables"][TName] extends TableType @@ -307,8 +309,10 @@ export class Database>} A promise that resolves to the created `Table` instance. */ - createTable(schema: TableSchema): Promise> { - return Table.create(this._connection, this.name, schema, this._ai); + createTable( + schema: TableSchema, + ): Promise> { + return Table.create(this._connection, this.name, schema, this._ai); } /** diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index 7878339..c434345 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -4,6 +4,7 @@ import { Workspace, type ConnectWorkspaceConfig, type WorkspaceType } from "./wo export type * from "./types"; export { escape } from "mysql2"; +export { QueryBuilder } from "./query/builder"; /** * Configuration object for initializing a `SingleStoreClient` instance. diff --git a/packages/client/src/query/builder.ts b/packages/client/src/query/builder.ts index 7f9a1aa..311c1ad 100644 --- a/packages/client/src/query/builder.ts +++ b/packages/client/src/query/builder.ts @@ -3,7 +3,12 @@ import { escape } from "mysql2"; import type { DatabaseType } from "../database"; import type { TableType } from "../table"; -export type SelectClause = (keyof TColumn | (string & {}))[]; +export type SelectClause< + TTableName extends string, + TTableType extends TableType, + TDatabaseType extends DatabaseType, + _TTableColumns = TTableType["columns"], +> = ((string & {}) | keyof _TTableColumns)[]; export type WhereOperator = TColumnValue extends string ? { @@ -26,44 +31,65 @@ export type WhereOperator = TColumnValue extends string } : never; -export type WhereClause = { - [K in keyof TColumn]?: WhereOperator | TColumn[K]; -} & { - OR?: WhereClause[]; - NOT?: WhereClause; +export type WhereClause< + TTableName extends string, + TTableType extends TableType, + TDatabaseType extends DatabaseType, + _TTableColumns = TTableType["columns"], +> = { [K in keyof _TTableColumns]?: WhereOperator<_TTableColumns[K]> | _TTableColumns[K] } & { + OR?: WhereClause[]; + NOT?: WhereClause; }; -export type GroupByClause = (keyof TColumn)[]; +export type GroupByClause< + TTableName extends string, + TTableType extends TableType, + TDatabaseType extends DatabaseType, + _TTableColumns = TTableType["columns"], +> = ((string & {}) | keyof _TTableColumns)[]; export type OrderByDirection = "asc" | "desc"; -export type OrderByClause = { - [K in keyof TColumn]?: OrderByDirection; -}; - -export interface QueryBuilderParams { - select?: SelectClause; - where?: WhereClause; - groupBy?: GroupByClause; - orderBy?: OrderByClause; +export type OrderByClause< + TTableName extends string, + TTableType extends TableType, + TDatabaseType extends DatabaseType, + _TTableColumns = TTableType["columns"], +> = { [K in string & {}]: OrderByDirection } & { [K in keyof _TTableColumns]?: OrderByDirection }; + +export interface QueryBuilderParams< + TTableName extends string, + TTableType extends TableType, + TDatabaseType extends DatabaseType, +> { + select?: SelectClause; + where?: WhereClause; + groupBy?: GroupByClause; + orderBy?: OrderByClause; limit?: number; offset?: number; } -export type ExtractQuerySelectedColumn | undefined> = - TParams extends QueryBuilderParams - ? TParams["select"] extends (keyof TColumn)[] - ? Pick - : TColumn - : TColumn; +export type AnyQueryBuilderParams = QueryBuilderParams; + +export type ExtractQuerySelectedColumn< + TTableName extends string, + TDatabaseType extends DatabaseType, + TParams extends AnyQueryBuilderParams | undefined, + _Table extends TDatabaseType["tables"][TTableName] = TDatabaseType["tables"][TTableName], +> = TParams extends AnyQueryBuilderParams + ? TParams["select"] extends (keyof _Table["columns"])[] + ? Pick<_Table["columns"], TParams["select"][number]> + : _Table["columns"] + : _Table["columns"]; -export class QueryBuilder { +export class QueryBuilder { constructor( private _databaseName: string, - private _tableName: string, + private _tableName: TName, ) {} - buildSelectClause(select?: SelectClause) { + buildSelectClause(select?: SelectClause) { const columns = select ? select : ["*"]; return `SELECT ${columns.join(", ")}`; } @@ -97,7 +123,7 @@ export class QueryBuilder): string { + buildWhereClause(conditions?: WhereClause): string { if (!conditions || !Object.keys(conditions).length) return ""; const clauses: string[] = []; @@ -119,11 +145,11 @@ export class QueryBuilder): string { + buildGroupByClause(columns?: GroupByClause): string { return columns?.length ? `GROUP BY ${columns.join(", ")}` : ""; } - buildOrderByClause(clauses?: OrderByClause): string { + buildOrderByClause(clauses?: OrderByClause): string { if (!clauses) return ""; const condition = Object.entries(clauses) @@ -145,7 +171,7 @@ export class QueryBuilder) { + buildClauses>(params?: TParams) { return { select: this.buildSelectClause(params?.select), from: this.buildFromClause(), @@ -157,7 +183,7 @@ export class QueryBuilder) { - return Object.values(this.buildClauses(params)).join(" "); + buildQuery>(params?: TParams) { + return Object.values(this.buildClauses(params)).join(" ").trim(); } } diff --git a/packages/client/src/table.ts b/packages/client/src/table.ts index baea553..a00d5aa 100644 --- a/packages/client/src/table.ts +++ b/packages/client/src/table.ts @@ -18,16 +18,17 @@ export interface TableType { /** * Interface representing the schema of a table, including its columns, primary keys, full-text keys, and additional clauses. * + * @typeParam TName - A type extending `string` that defines the name of the table. * @typeParam TType - A type extending `TableType` that defines the structure of the table. * - * @property {string} name - The name of the table. + * @property {TName} name - The name of the table. * @property {Object} columns - An object where each key is a column name and each value is the schema of that column, excluding the name. * @property {string[]} [primaryKeys] - An optional array of column names that form the primary key. * @property {string[]} [fulltextKeys] - An optional array of column names that form full-text keys. * @property {string[]} [clauses] - An optional array of additional SQL clauses for the table definition. */ -export interface TableSchema { - name: string; +export interface TableSchema { + name: TName; columns: { [K in keyof TType["columns"]]: Omit }; primaryKeys?: string[]; fulltextKeys?: string[]; @@ -87,6 +88,7 @@ type VectorScoreKey = "v_score"; * @property {VectorScoreKey} vScoreKey - The key used for vector scoring in vector search queries, defaulting to `"v_score"`. */ export class Table< + TName extends string = string, TType extends TableType = TableType, TDatabaseType extends DatabaseType = DatabaseType, TAi extends AnyAI | undefined = undefined, @@ -97,7 +99,7 @@ export class Table< constructor( private _connection: Connection, public databaseName: string, - public name: string, + public name: TName, private _ai?: TAi, ) { this._path = [databaseName, name].join("."); @@ -144,7 +146,7 @@ export class Table< * * @returns {string} An SQL string representing the table definition. */ - static schemaToClauses(schema: TableSchema): string { + static schemaToClauses(schema: TableSchema): string { const clauses: string[] = [ ...Object.entries(schema.columns).map(([name, schema]) => { return Column.schemaToClauses({ ...schema, name }); @@ -160,6 +162,7 @@ export class Table< /** * Creates a new table in the database with the specified schema. * + * @typeParam TName - The name of the table, which extends `string`. * @typeParam TType - The type of the table, which extends `TableType`. * @typeParam TDatabaseType - The type of the database, which extends `DatabaseType`. * @typeParam TAi - The type of AI functionalities integrated with the table, which can be undefined. @@ -169,24 +172,25 @@ export class Table< * @param {TableSchema} schema - The schema defining the structure of the table. * @param {TAi} [ai] - Optional AI functionalities to associate with the table. * - * @returns {Promise>} A promise that resolves to the created `Table` instance. + * @returns {Promise>} A promise that resolves to the created `Table` instance. */ static async create< + TName extends string = string, TType extends TableType = TableType, TDatabaseType extends DatabaseType = DatabaseType, TAi extends AnyAI | undefined = undefined, >( connection: Connection, databaseName: string, - schema: TableSchema, + schema: TableSchema, ai?: TAi, - ): Promise> { + ): Promise> { const clauses = Table.schemaToClauses(schema); await connection.client.execute(`\ CREATE TABLE IF NOT EXISTS ${databaseName}.${schema.name} (${clauses}) `); - return new Table(connection, databaseName, schema.name, ai); + return new Table(connection, databaseName, schema.name, ai); } /** @@ -331,13 +335,11 @@ export class Table< * * @param {TParams} params - The arguments defining the query, including selected columns, filters, and other options. * - * @returns {Promise<(ExtractQuerySelectedColumn & RowDataPacket)[]>} A promise that resolves to an array of selected rows. + * @returns {Promise<(ExtractQuerySelectedColumn & RowDataPacket)[]>} A promise that resolves to an array of selected rows. */ - async find | undefined>( - params?: TParams, - ): Promise<(ExtractQuerySelectedColumn & RowDataPacket)[]> { - type SelectedColumn = ExtractQuerySelectedColumn; - const queryBuilder = new QueryBuilder(this.databaseName, this.name); + async find>(params?: TParams) { + type SelectedColumn = ExtractQuerySelectedColumn; + const queryBuilder = new QueryBuilder(this.databaseName, this.name); const query = queryBuilder.buildQuery(params); const [rows] = await this._connection.client.execute<(SelectedColumn & RowDataPacket)[]>(query); return rows; @@ -347,11 +349,14 @@ export class Table< * Updates rows in the table based on the specified values and filters. * * @param {Partial} values - The values to update in the table. - * @param {WhereClause} where - The where clause to apply to the update query. + * @param {WhereClause} where - The where clause to apply to the update query. * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the update is complete. */ - update(values: Partial, where: WhereClause): Promise<[ResultSetHeader, FieldPacket[]]> { + update( + values: Partial, + where: WhereClause, + ): Promise<[ResultSetHeader, FieldPacket[]]> { const _where = new QueryBuilder(this.databaseName, this.name).buildWhereClause(where); const columnAssignments = Object.keys(values) @@ -365,11 +370,11 @@ export class Table< /** * Deletes rows from the table based on the specified filters. If no filters are provided, the table is truncated. * - * @param {WhereClause} [where] - The where clause to apply to the delete query. + * @param {WhereClause} [where] - The where clause to apply to the delete query. * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the delete operation is complete. */ - delete(where?: WhereClause): Promise<[ResultSetHeader, FieldPacket[]]> { + delete(where?: WhereClause): Promise<[ResultSetHeader, FieldPacket[]]> { if (!where) return this.truncate(); const _where = new QueryBuilder(this.databaseName, this.name).buildWhereClause(where); const query = `DELETE FROM ${this._path} ${_where}`; @@ -395,7 +400,7 @@ export class Table< * @param {TQueryParams} [queryParams] - Optional query builder parameters to refine the search, such as filters, * groupings, orderings, limits, and offsets. * - * @returns {Promise<(ExtractQuerySelectedColumn & { [K in VectorScoreKey]: number } & RowDataPacket)[]>} + * @returns {Promise<(ExtractQuerySelectedColumn & { v_score: number } & RowDataPacket)[]>} * A promise that resolves to an array of rows matching the vector search criteria, each row including * the selected columns and a vector similarity score. */ @@ -405,17 +410,12 @@ export class Table< vectorColumn: TableColumnName; embeddingParams?: TAi extends AnyAI ? Parameters[1] : never; }, - TQueryParams extends QueryBuilderParams, - >( - params: TParams, - queryParams?: TQueryParams, - ): Promise< - (ExtractQuerySelectedColumn & { [K in VectorScoreKey]: number } & RowDataPacket)[] - > { - type SelectedColumn = ExtractQuerySelectedColumn; + TQueryParams extends QueryBuilderParams, + >(params: TParams, queryParams?: TQueryParams) { + type SelectedColumn = ExtractQuerySelectedColumn; type ResultColumn = SelectedColumn & { [K in VectorScoreKey]: number }; - const clauses = new QueryBuilder(this.databaseName, this.name).buildClauses(queryParams); + const clauses = new QueryBuilder(this.databaseName, this.name).buildClauses(queryParams); const promptEmbedding = (await this.ai.embeddings.create(params.prompt, params.embeddingParams))[0] || []; let orderByClause = `ORDER BY ${this.vScoreKey} DESC`; @@ -464,7 +464,7 @@ export class Table< async createChatCompletion< TParams extends Parameters[0] & (TAi extends AnyAI ? Parameters[0] : never) & { template?: string }, - TQueryParams extends QueryBuilderParams, + TQueryParams extends QueryBuilderParams, >(params: TParams, queryParams?: TQueryParams): Promise> { const { prompt, systemRole, template, vectorColumn, embeddingParams, ...createChatCompletionParams } = params; diff --git a/packages/rag/src/chat/index.ts b/packages/rag/src/chat/index.ts index 2771b39..b4032e4 100644 --- a/packages/rag/src/chat/index.ts +++ b/packages/rag/src/chat/index.ts @@ -60,6 +60,9 @@ export class Chat< TDatabase extends AnyDatabase = AnyDatabase, TAi extends AnyAI = AnyAI, TChatCompletionTool extends AnyChatCompletionTool[] | undefined = undefined, + TTableName extends string = string, + TSessionsTableName extends string = string, + TMessagesTableName extends string = string, > { constructor( private _database: TDatabase, @@ -70,9 +73,9 @@ export class Chat< public name: string, public systemRole: ChatSession["systemRole"], public store: ChatSession["store"], - public tableName: string, - public sessionsTableName: ChatSession["tableName"], - public messagesTableName: ChatSession["messagesTableName"], + public tableName: TTableName, + public sessionsTableName: TSessionsTableName, + public messagesTableName: TMessagesTableName, ) {} /** @@ -89,8 +92,8 @@ export class Chat< private static _createTable( database: TDatabase, name: TName, - ): Promise> { - return database.createTable({ + ): Promise> { + return database.createTable({ name, columns: { id: { type: "bigint", autoIncrement: true, primaryKey: true }, @@ -116,13 +119,22 @@ export class Chat< * @param {TAi} ai - The AI instance used in the chat. * @param {TConfig} [config] - The configuration object for the chat. * - * @returns {Promise>} A promise that resolves to the created `Chat` instance. + * @returns {Promise>} A promise that resolves to the created `Chat` instance. */ static async create( database: TDatabase, ai: TAi, config?: TConfig, - ): Promise> { + ): Promise< + Chat< + TDatabase, + TAi, + TConfig["tools"], + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + TConfig["sessionsTableName"] extends string ? TConfig["sessionsTableName"] : string, + TConfig["messagesTableName"] extends string ? TConfig["messagesTableName"] : string + > + > { const createdAt: Chat["createdAt"] = new Date().toISOString().replace("T", " ").substring(0, 23); const _config: ChatConfig = { @@ -157,7 +169,7 @@ export class Chat< id = rows?.[0].insertId; } - return new Chat( + return new Chat( database, ai, _config.tools, @@ -188,7 +200,7 @@ export class Chat< tableName: Chat["tableName"], sessionsTable: Chat["sessionsTableName"], messagesTableName: Chat["messagesTableName"], - where?: Parameters["delete"]>[0], + where?: Parameters["delete"]>[0], ): Promise<[[ResultSetHeader, FieldPacket[]], [ResultSetHeader, FieldPacket[]][]]> { const table = database.table(tableName); const deletedRowIds = await table.find({ select: ["id"], where }); @@ -206,7 +218,7 @@ export class Chat< * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the update operation is complete. */ - async update(values: Parameters["update"]>[0]): Promise<[ResultSetHeader, FieldPacket[]]> { + async update(values: Parameters["update"]>[0]): Promise<[ResultSetHeader, FieldPacket[]]> { const result = await this._database.table(this.tableName).update(values, { id: this.id }); for (const [key, value] of Object.entries(values)) { @@ -234,9 +246,19 @@ export class Chat< * * @param {TName} [name] - The name of the session. * - * @returns {Promise>} A promise that resolves to the created `ChatSession` instance. + * @returns {Promise>} A promise that resolves to the created `ChatSession` instance. */ - createSession(name?: TName): Promise> { + createSession( + name?: TName, + ): Promise< + ChatSession< + TDatabase, + TAi, + TChatCompletionTool, + TSessionsTableName extends string ? TSessionsTableName : string, + TMessagesTableName extends string ? TMessagesTableName : string + > + > { return ChatSession.create(this._database, this._ai, { chatId: this.id, name, @@ -249,18 +271,18 @@ export class Chat< } /** - * Selects chat sessions from the current chat based on the provided filters and options. + * Finds chat sessions from the current chat based on the provided filters and options. * - * @typeParam T - The parameters passed to the `select` method of the `Table` class. + * @typeParam T - The parameters passed to the `find` method of the `Table` class. * - * @param {...T} args - The arguments defining the filters and options for selecting sessions. + * @param {params} params - The parameters defining the filters and options for finding sessions. * - * @returns {Promise[]>} A promise that resolves to an array of `ChatSession` instances representing the selected sessions. + * @returns {Promise[]>} A promise that resolves to an array of `ChatSession` instances representing the found sessions. */ - async selectSessions>["find"]>>( - ...args: T - ): Promise[]> { - const rows = await this._database.table(this.sessionsTableName).find(...args); + async findSessions( + params?: Parameters>["find"]>[0], + ): Promise[]> { + const rows = await this._database.table(this.sessionsTableName).find(params); return rows.map( (row) => diff --git a/packages/rag/src/chat/message.ts b/packages/rag/src/chat/message.ts index a9b554e..5b4f0e0 100644 --- a/packages/rag/src/chat/message.ts +++ b/packages/rag/src/chat/message.ts @@ -37,7 +37,7 @@ export interface ChatMessagesTable { * @property {boolean} store - Whether the message is stored in the database. * @property {string} tableName - The name of the table where the message is stored. */ -export class ChatMessage { +export class ChatMessage { constructor( private _database: TDatabase, public id: number | undefined, @@ -46,7 +46,7 @@ export class ChatMessage { public role: ChatCompletionMessage["role"], public content: ChatCompletionMessage["content"], public store: boolean, - public tableName: string, + public tableName: TTableName, ) {} /** @@ -63,8 +63,8 @@ export class ChatMessage { static createTable( database: TDatabase, name: TName, - ): Promise> { - return database.createTable({ + ): Promise> { + return database.createTable({ name, columns: { id: { type: "bigint", autoIncrement: true, primaryKey: true }, @@ -90,7 +90,7 @@ export class ChatMessage { static async create( database: TDatabase, config: TConfig, - ): Promise> { + ): Promise> { const { sessionId, role, content, store, tableName } = config; const createdAt: ChatMessage["createdAt"] = new Date().toISOString().replace("T", " ").substring(0, 23); let id: ChatMessage["id"]; @@ -115,7 +115,7 @@ export class ChatMessage { static delete( database: AnyDatabase, tableName: ChatMessage["tableName"], - where?: Parameters["delete"]>[0], + where?: Parameters["delete"]>[0], ): Promise<[ResultSetHeader, FieldPacket[]]> { return database.table(tableName).delete(where); } @@ -127,7 +127,9 @@ export class ChatMessage { * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the update operation is complete. */ - async update(values: Parameters["update"]>[0]): Promise<[ResultSetHeader, FieldPacket[]]> { + async update( + values: Parameters["update"]>[0], + ): Promise<[ResultSetHeader, FieldPacket[]]> { const result = await this._database.table(this.tableName).update(values, { id: this.id }); for (const [key, value] of Object.entries(values)) { diff --git a/packages/rag/src/chat/session.ts b/packages/rag/src/chat/session.ts index c4cdc34..5e95a43 100644 --- a/packages/rag/src/chat/session.ts +++ b/packages/rag/src/chat/session.ts @@ -70,6 +70,8 @@ export class ChatSession< TDatabase extends AnyDatabase = AnyDatabase, TAi extends AnyAI = AnyAI, TChatCompletionTool extends AnyChatCompletionTool[] | undefined = undefined, + TTableName extends string = string, + TMessagesTableName extends string = string, > { constructor( private _database: TDatabase, @@ -81,8 +83,8 @@ export class ChatSession< public name: string, public systemRole: string, public store: ChatMessage["store"], - public tableName: string, - public messagesTableName: ChatMessage["tableName"], + public tableName: TTableName, + public messagesTableName: TMessagesTableName, ) {} /** @@ -99,8 +101,8 @@ export class ChatSession< static createTable( database: TDatabase, name: TName, - ): Promise> { - return database.createTable({ + ): Promise> { + return database.createTable({ name, columns: { id: { type: "bigint", autoIncrement: true, primaryKey: true }, @@ -122,13 +124,21 @@ export class ChatSession< * @param {TAi} ai - The AI instance used in the chat session. * @param {TConfig} [config] - The configuration object for the chat session. * - * @returns {Promise>} A promise that resolves to the created `ChatSession` instance. + * @returns {Promise>} A promise that resolves to the created `ChatSession` instance. */ static async create>( database: TDatabase, ai: TAi, config?: TConfig, - ): Promise> { + ): Promise< + ChatSession< + TDatabase, + TAi, + TConfig["tools"], + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + TConfig["messagesTableName"] extends string ? TConfig["messagesTableName"] : string + > + > { const createdAt: ChatSession["createdAt"] = new Date().toISOString().replace("T", " ").substring(0, 23); const _config: ChatSessionConfig = { @@ -150,7 +160,7 @@ export class ChatSession< id = rows?.[0].insertId; } - return new ChatSession( + return new ChatSession( database, ai, _config.tools, @@ -179,7 +189,7 @@ export class ChatSession< database: AnyDatabase, tableName: ChatSession["tableName"], messagesTableName: ChatSession["messagesTableName"], - where?: Parameters["delete"]>[0], + where?: Parameters["delete"]>[0], ): Promise<[ResultSetHeader, FieldPacket[]][]> { const table = database.table(tableName); const deletedRowIds = await table.find({ select: ["id"], where }); @@ -197,7 +207,9 @@ export class ChatSession< * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the update operation is complete. */ - async update(values: Parameters["update"]>[0]): Promise<[ResultSetHeader, FieldPacket[]]> { + async update( + values: Parameters["update"]>[0], + ): Promise<[ResultSetHeader, FieldPacket[]]> { const result = await this._database.table(this.tableName).update(values, { id: this.id }); for (const [key, value] of Object.entries(values)) { @@ -227,12 +239,12 @@ export class ChatSession< * @param {TRole} role - The role of the message sender. * @param {TContent} content - The content of the chat message. * - * @returns {Promise>} A promise that resolves to the created `ChatMessage` instance. + * @returns {Promise>} A promise that resolves to the created `ChatMessage` instance. */ createMessage( role: TRole, content: TContent, - ): Promise> { + ): Promise> { return ChatMessage.create(this._database, { sessionId: this.id, role, @@ -243,18 +255,16 @@ export class ChatSession< } /** - * Selects chat messages from the current session based on the provided filters and options. - * - * @typeParam T - The parameters passed to the `select` method of the `Table` class. + * Finds chat messages from the current session based on the provided filters and options. * - * @param {...T} args - The arguments defining the filters and options for selecting messages. + * @param {params} params - The parameters defining the filters and options for finding messages. * - * @returns {Promise[]>} A promise that resolves to an array of `ChatMessage` instances representing the selected messages. + * @returns {Promise[]>} A promise that resolves to an array of `ChatMessage` instances representing the found messages. */ - async selectMessages>["find"]>>( - ...args: T - ): Promise[]> { - const rows = await this._database.table(this.messagesTableName).find(...args); + async findMessages( + params?: Parameters>["find"]>[0], + ): Promise[]> { + const rows = await this._database.table(this.messagesTableName).find(params); return rows.map((row) => { return new ChatMessage( @@ -320,7 +330,7 @@ export class ChatSession< if (loadDatabaseSchema || loadHistory) { const [databaseSchema, historyMessages] = await Promise.all([ loadDatabaseSchema ? this._database.describe() : undefined, - loadHistory ? this.selectMessages({ orderBy: { createdAt: "asc" } }) : undefined, + loadHistory ? this.findMessages({ orderBy: { createdAt: "asc" } }) : undefined, ]); if (databaseSchema) { diff --git a/packages/rag/src/index.ts b/packages/rag/src/index.ts index 478e947..2472392 100644 --- a/packages/rag/src/index.ts +++ b/packages/rag/src/index.ts @@ -1,14 +1,5 @@ import type { AnyAI, AnyChatCompletionTool } from "@singlestore/ai"; -import type { - AnyDatabase, - Database, - DatabaseType, - FieldPacket, - InferDatabaseType, - QueryBuilderParams, - ResultSetHeader, - Table, -} from "@singlestore/client"; +import type { AnyDatabase, FieldPacket, InferDatabaseType, ResultSetHeader, Table } from "@singlestore/client"; import { Chat, type CreateChatConfig, type ChatsTable } from "./chat"; @@ -63,31 +54,56 @@ export class RAG>} A promise that resolves to the created `Chat` instance. + * @returns {Promise>} A promise that resolves to the created `Chat` instance. */ - createChat(config?: TConfig): Promise> { + createChat( + config?: TConfig, + ): Promise< + Chat< + TDatabase, + TAi, + TConfig["tools"], + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + TConfig["sessionsTableName"] extends string ? TConfig["sessionsTableName"] : string, + TConfig["messagesTableName"] extends string ? TConfig["messagesTableName"] : string + > + > { return Chat.create(this._database, this._ai, config); } /** - * Selects chat instances from the database based on the provided configuration and arguments. + * Finds chat instances from the database based on the provided configuration and parameters. * - * @typeParam TConfig - The configuration object for selecting chats. - * @typeParam TSelectArgs - The parameters passed to the `select` method of the `Table` class. + * @typeParam TConfig - The configuration object for finding chats. * - * @param {TConfig} [config] - The configuration object for selecting chats. - * @param {...TSelectArgs} selectArgs - The arguments defining the filters and options for selecting chats. + * @param {TConfig} [config] - The configuration object for finding chats. + * @param {findParams} findParams - The parameters defining the filters and options for finding chats. * - * @returns {Promise[]>} A promise that resolves to an array of `Chat` instances representing the selected chats. + * @returns {Promise[]>} A promise that resolves to an array of `Chat` instances representing the found chats. */ async findChats( config?: TConfig, - findParams?: QueryBuilderParams>, - ): Promise[]> { + findParams?: Parameters< + Table< + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + ChatsTable, + InferDatabaseType + >["find"] + >[0], + ): Promise< + Chat< + TDatabase, + TAi, + TConfig["tools"], + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + string, + string + >[] + > { const rows = await this._database.table(config?.tableName || "chats").find(findParams); return rows.map( (row) => - new Chat( + new Chat( this._database, this._ai, config?.tools || [],