From cf4437ff436a0a00c85d0d9eeaac541a6c3d466e Mon Sep 17 00:00:00 2001 From: demenskyi Date: Wed, 28 Aug 2024 13:33:35 -0700 Subject: [PATCH 1/2] Join support to the QueryBuilder added --- packages/client/src/database.ts | 12 +- packages/client/src/index.ts | 1 + packages/client/src/query/builder.ts | 176 ++++++++++++++++++++++----- packages/client/src/table.ts | 57 +++++---- 4 files changed, 188 insertions(+), 58 deletions(-) diff --git a/packages/client/src/database.ts b/packages/client/src/database.ts index e5dcaa5..9df2814 100644 --- a/packages/client/src/database.ts +++ b/packages/client/src/database.ts @@ -23,7 +23,7 @@ export interface DatabaseType { */ export interface DatabaseSchema { name: string; - tables?: { [K in keyof TType["tables"]]: Omit, "name"> }; + tables?: { [K in keyof TType["tables"]]: Omit, "name"> }; } /** @@ -255,7 +255,7 @@ export class Database} A `Table` instance representing the specified table. + * @returns {Table} A `Table` instance representing the specified table. */ table< TType, @@ -263,6 +263,7 @@ export class Database( name: TName, ): Table< + TName, TType extends TableType ? TType : TDatabaseType["tables"][TName] extends TableType @@ -272,6 +273,7 @@ export class Database { return new Table< + TName, TType extends TableType ? TType : TDatabaseType["tables"][TName] extends TableType @@ -307,8 +309,10 @@ export class Database>} A promise that resolves to the created `Table` instance. */ - createTable(schema: TableSchema): Promise> { - return Table.create(this._connection, this.name, schema, this._ai); + createTable( + schema: TableSchema, + ): Promise> { + return Table.create(this._connection, this.name, schema, this._ai); } /** diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index 7878339..c434345 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -4,6 +4,7 @@ import { Workspace, type ConnectWorkspaceConfig, type WorkspaceType } from "./wo export type * from "./types"; export { escape } from "mysql2"; +export { QueryBuilder } from "./query/builder"; /** * Configuration object for initializing a `SingleStoreClient` instance. diff --git a/packages/client/src/query/builder.ts b/packages/client/src/query/builder.ts index 7f9a1aa..08f90c0 100644 --- a/packages/client/src/query/builder.ts +++ b/packages/client/src/query/builder.ts @@ -1,9 +1,22 @@ import { escape } from "mysql2"; import type { DatabaseType } from "../database"; -import type { TableType } from "../table"; -export type SelectClause = (keyof TColumn | (string & {}))[]; +export type SelectClause< + TTableName extends string, + TDatabaseType extends DatabaseType, + TJoin extends AnyJoinClauseRecord[] | undefined = undefined, + _TTableColumns = TDatabaseType["tables"][TTableName]["columns"], +> = ( + | (string & {}) + | (TJoin extends AnyJoinClauseRecord[] + ? + | `${TTableName}.${Extract}` + | { + [K in TJoin[number] as K["as"]]: `${K["as"]}.${Extract}`; + }[TJoin[number]["as"]] + : keyof _TTableColumns) +)[]; export type WhereOperator = TColumnValue extends string ? { @@ -26,50 +39,131 @@ export type WhereOperator = TColumnValue extends string } : never; -export type WhereClause = { - [K in keyof TColumn]?: WhereOperator | TColumn[K]; -} & { - OR?: WhereClause[]; - NOT?: WhereClause; +export type WhereClause< + TTableName extends string, + TDatabaseType extends DatabaseType, + TJoin extends AnyJoinClauseRecord[] | undefined = undefined, + _TTableColumns = TDatabaseType["tables"][TTableName]["columns"], +> = (TJoin extends AnyJoinClauseRecord[] + ? { + [K in keyof _TTableColumns as `${TTableName}.${Extract}`]?: + | WhereOperator<_TTableColumns[K]> + | _TTableColumns[K]; + } & { + [K in TJoin[number] as K["as"]]: { + [C in keyof TDatabaseType["tables"][K["table"]]["columns"] as `${K["as"]}.${Extract}`]?: + | WhereOperator + | TDatabaseType["tables"][K["table"]]["columns"][C]; + }; + }[TJoin[number]["as"]] + : { [K in keyof _TTableColumns]?: WhereOperator<_TTableColumns[K]> | _TTableColumns[K] }) & { + OR?: WhereClause[]; + NOT?: WhereClause; }; -export type GroupByClause = (keyof TColumn)[]; +export type GroupByClause< + TTableName extends string, + TDatabaseType extends DatabaseType, + TJoin extends AnyJoinClauseRecord[] | undefined = undefined, + _TTableColumns = TDatabaseType["tables"][TTableName]["columns"], +> = ( + | (string & {}) + | (TJoin extends AnyJoinClauseRecord[] + ? + | `${TTableName}.${Extract}` + | { + [K in TJoin[number] as K["as"]]: `${K["as"]}.${Extract}`; + }[TJoin[number]["as"]] + : keyof _TTableColumns) +)[]; export type OrderByDirection = "asc" | "desc"; -export type OrderByClause = { - [K in keyof TColumn]?: OrderByDirection; -}; +export type OrderByClause< + TTableName extends string, + TDatabaseType extends DatabaseType, + TJoin extends AnyJoinClauseRecord[] | undefined = undefined, + _TTableColumns = TDatabaseType["tables"][TTableName]["columns"], +> = { [K in string & {}]: OrderByDirection } & (TJoin extends AnyJoinClauseRecord[] + ? { [K in keyof _TTableColumns as `${TTableName}.${Extract}`]?: OrderByDirection } & { + [K in TJoin[number] as K["as"]]: { + [C in keyof TDatabaseType["tables"][K["table"]]["columns"] as `${K["as"]}.${Extract}`]?: OrderByDirection; + }; + }[TJoin[number]["as"]] + : { [K in keyof _TTableColumns]?: OrderByDirection }); + +type JoinType = "INNER" | "LEFT" | "RIGHT" | "FULL"; +type JoinOperator = "=" | "<" | ">" | "<=" | ">=" | "!=" | "<=>"; + +export interface JoinClause< + TTableName extends string, + TDatabaseType extends DatabaseType, + TTable extends string, + TAs extends string, +> { + type?: JoinType; + table: TTable; + as: TAs; + on: [ + left: keyof TDatabaseType["tables"][TTable]["columns"], + operator: JoinOperator, + right: (string & {}) | `${TTableName}.${Extract}`, + ]; +} + +export type JoinClauseRecord = { + [K in keyof TDatabaseType["tables"]]: { + [Alias in TAs]: JoinClause, Alias>; + }[TAs]; +}[keyof TDatabaseType["tables"]]; + +export type AnyJoinClauseRecord = JoinClauseRecord; -export interface QueryBuilderParams { - select?: SelectClause; - where?: WhereClause; - groupBy?: GroupByClause; - orderBy?: OrderByClause; +export interface QueryBuilderParams< + TTableName extends string, + TDatabaseType extends DatabaseType, + TJoin extends AnyJoinClauseRecord[] | undefined = undefined, +> { + select?: SelectClause; + where?: WhereClause; + groupBy?: GroupByClause; + orderBy?: OrderByClause; limit?: number; offset?: number; + join?: TJoin; } -export type ExtractQuerySelectedColumn | undefined> = - TParams extends QueryBuilderParams +export type ExtractQuerySelectedColumn | undefined> = + TParams extends QueryBuilderParams ? TParams["select"] extends (keyof TColumn)[] ? Pick : TColumn : TColumn; -export class QueryBuilder { +export class QueryBuilder { constructor( private _databaseName: string, - private _tableName: string, + private _tableName: TName, ) {} - buildSelectClause(select?: SelectClause) { - const columns = select ? select : ["*"]; + buildSelectClause(select?: SelectClause, join?: AnyJoinClauseRecord[]) { + let columns = select ? select : ["*"]; + + if (join?.length) { + columns = columns.map((column) => { + if (String(column).includes(".")) { + const [tableName, columnName] = String(column).split("."); + return `${String(column)} AS ${tableName}_${columnName}`; + } + return column; + }); + } + return `SELECT ${columns.join(", ")}`; } - buildFromClause() { - return `FROM ${this._databaseName}.${String(this._tableName)}`; + buildFromClause(join?: AnyJoinClauseRecord[]) { + return `FROM ${this._databaseName}.${String(this._tableName)}${join?.length ? ` AS ${this._tableName}` : ""}`; } buildWhereCondition(column: string, operator: string, value: any): string { @@ -97,7 +191,7 @@ export class QueryBuilder): string { + buildWhereClause(conditions?: WhereClause): string { if (!conditions || !Object.keys(conditions).length) return ""; const clauses: string[] = []; @@ -119,11 +213,11 @@ export class QueryBuilder): string { + buildGroupByClause(columns?: GroupByClause): string { return columns?.length ? `GROUP BY ${columns.join(", ")}` : ""; } - buildOrderByClause(clauses?: OrderByClause): string { + buildOrderByClause(clauses?: OrderByClause): string { if (!clauses) return ""; const condition = Object.entries(clauses) @@ -145,10 +239,26 @@ export class QueryBuilder) { + buildJoinClause(clauses?: AnyJoinClauseRecord[]): string { + if (!clauses?.length) return ""; + return clauses + .map((join) => { + const joinType = join.type ? `${join.type} JOIN` : `JOIN`; + const tableName = `${this._databaseName}.${String(join.table)}`; + const as = `AS ${join.as}`; + const on = ["ON", `${String(join.as)}.${String(join.on[0])}`, join.on[1], join.on[2]].join(" "); + return [joinType, tableName, as, on].filter(Boolean).join(" "); + }) + .join(" "); + } + + buildClauses[] | undefined = undefined>( + params?: QueryBuilderParams, + ) { return { - select: this.buildSelectClause(params?.select), - from: this.buildFromClause(), + select: this.buildSelectClause(params?.select, params?.join), + from: this.buildFromClause(params?.join), + join: this.buildJoinClause(params?.join), where: this.buildWhereClause(params?.where), groupBy: this.buildGroupByClause(params?.groupBy), orderBy: this.buildOrderByClause(params?.orderBy), @@ -157,7 +267,9 @@ export class QueryBuilder) { - return Object.values(this.buildClauses(params)).join(" "); + buildQuery[] | undefined = undefined>( + params?: QueryBuilderParams, + ) { + return Object.values(this.buildClauses(params)).join(" ").trim(); } } diff --git a/packages/client/src/table.ts b/packages/client/src/table.ts index baea553..2e357a3 100644 --- a/packages/client/src/table.ts +++ b/packages/client/src/table.ts @@ -4,7 +4,13 @@ import type { FieldPacket, ResultSetHeader, RowDataPacket } from "mysql2/promise import { Column, type ColumnInfo, type ColumnSchema, type ColumnType } from "./column"; import { Connection } from "./connection"; -import { type ExtractQuerySelectedColumn, QueryBuilder, type WhereClause, type QueryBuilderParams } from "./query/builder"; +import { + type ExtractQuerySelectedColumn, + QueryBuilder, + type WhereClause, + type QueryBuilderParams, + type JoinClauseRecord, +} from "./query/builder"; /** * Interface representing the structure of a table type, including its columns. @@ -18,16 +24,17 @@ export interface TableType { /** * Interface representing the schema of a table, including its columns, primary keys, full-text keys, and additional clauses. * + * @typeParam TName - A type extending `string` that defines the name of the table. * @typeParam TType - A type extending `TableType` that defines the structure of the table. * - * @property {string} name - The name of the table. + * @property {TName} name - The name of the table. * @property {Object} columns - An object where each key is a column name and each value is the schema of that column, excluding the name. * @property {string[]} [primaryKeys] - An optional array of column names that form the primary key. * @property {string[]} [fulltextKeys] - An optional array of column names that form full-text keys. * @property {string[]} [clauses] - An optional array of additional SQL clauses for the table definition. */ -export interface TableSchema { - name: string; +export interface TableSchema { + name: TName; columns: { [K in keyof TType["columns"]]: Omit }; primaryKeys?: string[]; fulltextKeys?: string[]; @@ -87,6 +94,7 @@ type VectorScoreKey = "v_score"; * @property {VectorScoreKey} vScoreKey - The key used for vector scoring in vector search queries, defaulting to `"v_score"`. */ export class Table< + TName extends string = string, TType extends TableType = TableType, TDatabaseType extends DatabaseType = DatabaseType, TAi extends AnyAI | undefined = undefined, @@ -97,7 +105,7 @@ export class Table< constructor( private _connection: Connection, public databaseName: string, - public name: string, + public name: TName, private _ai?: TAi, ) { this._path = [databaseName, name].join("."); @@ -144,7 +152,7 @@ export class Table< * * @returns {string} An SQL string representing the table definition. */ - static schemaToClauses(schema: TableSchema): string { + static schemaToClauses(schema: TableSchema): string { const clauses: string[] = [ ...Object.entries(schema.columns).map(([name, schema]) => { return Column.schemaToClauses({ ...schema, name }); @@ -160,6 +168,7 @@ export class Table< /** * Creates a new table in the database with the specified schema. * + * @typeParam TName - The name of the table, which extends `string`. * @typeParam TType - The type of the table, which extends `TableType`. * @typeParam TDatabaseType - The type of the database, which extends `DatabaseType`. * @typeParam TAi - The type of AI functionalities integrated with the table, which can be undefined. @@ -169,24 +178,25 @@ export class Table< * @param {TableSchema} schema - The schema defining the structure of the table. * @param {TAi} [ai] - Optional AI functionalities to associate with the table. * - * @returns {Promise>} A promise that resolves to the created `Table` instance. + * @returns {Promise>} A promise that resolves to the created `Table` instance. */ static async create< + TName extends string = string, TType extends TableType = TableType, TDatabaseType extends DatabaseType = DatabaseType, TAi extends AnyAI | undefined = undefined, >( connection: Connection, databaseName: string, - schema: TableSchema, + schema: TableSchema, ai?: TAi, - ): Promise> { + ): Promise> { const clauses = Table.schemaToClauses(schema); await connection.client.execute(`\ CREATE TABLE IF NOT EXISTS ${databaseName}.${schema.name} (${clauses}) `); - return new Table(connection, databaseName, schema.name, ai); + return new Table(connection, databaseName, schema.name, ai); } /** @@ -333,11 +343,11 @@ export class Table< * * @returns {Promise<(ExtractQuerySelectedColumn & RowDataPacket)[]>} A promise that resolves to an array of selected rows. */ - async find | undefined>( - params?: TParams, - ): Promise<(ExtractQuerySelectedColumn & RowDataPacket)[]> { - type SelectedColumn = ExtractQuerySelectedColumn; - const queryBuilder = new QueryBuilder(this.databaseName, this.name); + async find[] | undefined = undefined>( + params?: QueryBuilderParams, + ) { + type SelectedColumn = ExtractQuerySelectedColumn; + const queryBuilder = new QueryBuilder(this.databaseName, this.name); const query = queryBuilder.buildQuery(params); const [rows] = await this._connection.client.execute<(SelectedColumn & RowDataPacket)[]>(query); return rows; @@ -347,11 +357,14 @@ export class Table< * Updates rows in the table based on the specified values and filters. * * @param {Partial} values - The values to update in the table. - * @param {WhereClause} where - The where clause to apply to the update query. + * @param {WhereClause} where - The where clause to apply to the update query. * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the update is complete. */ - update(values: Partial, where: WhereClause): Promise<[ResultSetHeader, FieldPacket[]]> { + update( + values: Partial, + where: WhereClause, + ): Promise<[ResultSetHeader, FieldPacket[]]> { const _where = new QueryBuilder(this.databaseName, this.name).buildWhereClause(where); const columnAssignments = Object.keys(values) @@ -365,11 +378,11 @@ export class Table< /** * Deletes rows from the table based on the specified filters. If no filters are provided, the table is truncated. * - * @param {WhereClause} [where] - The where clause to apply to the delete query. + * @param {WhereClause} [where] - The where clause to apply to the delete query. * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the delete operation is complete. */ - delete(where?: WhereClause): Promise<[ResultSetHeader, FieldPacket[]]> { + delete(where?: WhereClause): Promise<[ResultSetHeader, FieldPacket[]]> { if (!where) return this.truncate(); const _where = new QueryBuilder(this.databaseName, this.name).buildWhereClause(where); const query = `DELETE FROM ${this._path} ${_where}`; @@ -405,7 +418,7 @@ export class Table< vectorColumn: TableColumnName; embeddingParams?: TAi extends AnyAI ? Parameters[1] : never; }, - TQueryParams extends QueryBuilderParams, + TQueryParams extends QueryBuilderParams, >( params: TParams, queryParams?: TQueryParams, @@ -415,7 +428,7 @@ export class Table< type SelectedColumn = ExtractQuerySelectedColumn; type ResultColumn = SelectedColumn & { [K in VectorScoreKey]: number }; - const clauses = new QueryBuilder(this.databaseName, this.name).buildClauses(queryParams); + const clauses = new QueryBuilder(this.databaseName, this.name).buildClauses(queryParams); const promptEmbedding = (await this.ai.embeddings.create(params.prompt, params.embeddingParams))[0] || []; let orderByClause = `ORDER BY ${this.vScoreKey} DESC`; @@ -464,7 +477,7 @@ export class Table< async createChatCompletion< TParams extends Parameters[0] & (TAi extends AnyAI ? Parameters[0] : never) & { template?: string }, - TQueryParams extends QueryBuilderParams, + TQueryParams extends QueryBuilderParams, >(params: TParams, queryParams?: TQueryParams): Promise> { const { prompt, systemRole, template, vectorColumn, embeddingParams, ...createChatCompletionParams } = params; From 195f7649ccf8acc3075d70a2891e014d75db59fc Mon Sep 17 00:00:00 2001 From: demenskyi Date: Wed, 28 Aug 2024 21:48:38 -0700 Subject: [PATCH 2/2] Join support reverted --- examples/estore/index.ts | 2 +- packages/client/src/query/builder.ts | 168 +++++++-------------------- packages/client/src/table.ts | 39 +++---- packages/rag/src/chat/index.ts | 62 ++++++---- packages/rag/src/chat/message.ts | 16 +-- packages/rag/src/chat/session.ts | 52 +++++---- packages/rag/src/index.ts | 58 +++++---- 7 files changed, 174 insertions(+), 223 deletions(-) diff --git a/examples/estore/index.ts b/examples/estore/index.ts index dee8458..8ef01f7 100644 --- a/examples/estore/index.ts +++ b/examples/estore/index.ts @@ -155,7 +155,7 @@ async function main() { console.log("Executing database methods..."); console.log('Creating "users" table...'); - const newUsersTable = await db.createTable({ + const newUsersTable = await db.createTable<"users", Database["tables"]["users"]>({ name: "users", columns: { id: { type: "bigint", autoIncrement: true, primaryKey: true }, diff --git a/packages/client/src/query/builder.ts b/packages/client/src/query/builder.ts index 08f90c0..311c1ad 100644 --- a/packages/client/src/query/builder.ts +++ b/packages/client/src/query/builder.ts @@ -1,22 +1,14 @@ import { escape } from "mysql2"; import type { DatabaseType } from "../database"; +import type { TableType } from "../table"; export type SelectClause< TTableName extends string, + TTableType extends TableType, TDatabaseType extends DatabaseType, - TJoin extends AnyJoinClauseRecord[] | undefined = undefined, - _TTableColumns = TDatabaseType["tables"][TTableName]["columns"], -> = ( - | (string & {}) - | (TJoin extends AnyJoinClauseRecord[] - ? - | `${TTableName}.${Extract}` - | { - [K in TJoin[number] as K["as"]]: `${K["as"]}.${Extract}`; - }[TJoin[number]["as"]] - : keyof _TTableColumns) -)[]; + _TTableColumns = TTableType["columns"], +> = ((string & {}) | keyof _TTableColumns)[]; export type WhereOperator = TColumnValue extends string ? { @@ -41,129 +33,69 @@ export type WhereOperator = TColumnValue extends string export type WhereClause< TTableName extends string, + TTableType extends TableType, TDatabaseType extends DatabaseType, - TJoin extends AnyJoinClauseRecord[] | undefined = undefined, - _TTableColumns = TDatabaseType["tables"][TTableName]["columns"], -> = (TJoin extends AnyJoinClauseRecord[] - ? { - [K in keyof _TTableColumns as `${TTableName}.${Extract}`]?: - | WhereOperator<_TTableColumns[K]> - | _TTableColumns[K]; - } & { - [K in TJoin[number] as K["as"]]: { - [C in keyof TDatabaseType["tables"][K["table"]]["columns"] as `${K["as"]}.${Extract}`]?: - | WhereOperator - | TDatabaseType["tables"][K["table"]]["columns"][C]; - }; - }[TJoin[number]["as"]] - : { [K in keyof _TTableColumns]?: WhereOperator<_TTableColumns[K]> | _TTableColumns[K] }) & { - OR?: WhereClause[]; - NOT?: WhereClause; + _TTableColumns = TTableType["columns"], +> = { [K in keyof _TTableColumns]?: WhereOperator<_TTableColumns[K]> | _TTableColumns[K] } & { + OR?: WhereClause[]; + NOT?: WhereClause; }; export type GroupByClause< TTableName extends string, + TTableType extends TableType, TDatabaseType extends DatabaseType, - TJoin extends AnyJoinClauseRecord[] | undefined = undefined, - _TTableColumns = TDatabaseType["tables"][TTableName]["columns"], -> = ( - | (string & {}) - | (TJoin extends AnyJoinClauseRecord[] - ? - | `${TTableName}.${Extract}` - | { - [K in TJoin[number] as K["as"]]: `${K["as"]}.${Extract}`; - }[TJoin[number]["as"]] - : keyof _TTableColumns) -)[]; + _TTableColumns = TTableType["columns"], +> = ((string & {}) | keyof _TTableColumns)[]; export type OrderByDirection = "asc" | "desc"; export type OrderByClause< TTableName extends string, + TTableType extends TableType, TDatabaseType extends DatabaseType, - TJoin extends AnyJoinClauseRecord[] | undefined = undefined, - _TTableColumns = TDatabaseType["tables"][TTableName]["columns"], -> = { [K in string & {}]: OrderByDirection } & (TJoin extends AnyJoinClauseRecord[] - ? { [K in keyof _TTableColumns as `${TTableName}.${Extract}`]?: OrderByDirection } & { - [K in TJoin[number] as K["as"]]: { - [C in keyof TDatabaseType["tables"][K["table"]]["columns"] as `${K["as"]}.${Extract}`]?: OrderByDirection; - }; - }[TJoin[number]["as"]] - : { [K in keyof _TTableColumns]?: OrderByDirection }); - -type JoinType = "INNER" | "LEFT" | "RIGHT" | "FULL"; -type JoinOperator = "=" | "<" | ">" | "<=" | ">=" | "!=" | "<=>"; - -export interface JoinClause< - TTableName extends string, - TDatabaseType extends DatabaseType, - TTable extends string, - TAs extends string, -> { - type?: JoinType; - table: TTable; - as: TAs; - on: [ - left: keyof TDatabaseType["tables"][TTable]["columns"], - operator: JoinOperator, - right: (string & {}) | `${TTableName}.${Extract}`, - ]; -} - -export type JoinClauseRecord = { - [K in keyof TDatabaseType["tables"]]: { - [Alias in TAs]: JoinClause, Alias>; - }[TAs]; -}[keyof TDatabaseType["tables"]]; - -export type AnyJoinClauseRecord = JoinClauseRecord; + _TTableColumns = TTableType["columns"], +> = { [K in string & {}]: OrderByDirection } & { [K in keyof _TTableColumns]?: OrderByDirection }; export interface QueryBuilderParams< TTableName extends string, + TTableType extends TableType, TDatabaseType extends DatabaseType, - TJoin extends AnyJoinClauseRecord[] | undefined = undefined, > { - select?: SelectClause; - where?: WhereClause; - groupBy?: GroupByClause; - orderBy?: OrderByClause; + select?: SelectClause; + where?: WhereClause; + groupBy?: GroupByClause; + orderBy?: OrderByClause; limit?: number; offset?: number; - join?: TJoin; } -export type ExtractQuerySelectedColumn | undefined> = - TParams extends QueryBuilderParams - ? TParams["select"] extends (keyof TColumn)[] - ? Pick - : TColumn - : TColumn; +export type AnyQueryBuilderParams = QueryBuilderParams; -export class QueryBuilder { +export type ExtractQuerySelectedColumn< + TTableName extends string, + TDatabaseType extends DatabaseType, + TParams extends AnyQueryBuilderParams | undefined, + _Table extends TDatabaseType["tables"][TTableName] = TDatabaseType["tables"][TTableName], +> = TParams extends AnyQueryBuilderParams + ? TParams["select"] extends (keyof _Table["columns"])[] + ? Pick<_Table["columns"], TParams["select"][number]> + : _Table["columns"] + : _Table["columns"]; + +export class QueryBuilder { constructor( private _databaseName: string, private _tableName: TName, ) {} - buildSelectClause(select?: SelectClause, join?: AnyJoinClauseRecord[]) { - let columns = select ? select : ["*"]; - - if (join?.length) { - columns = columns.map((column) => { - if (String(column).includes(".")) { - const [tableName, columnName] = String(column).split("."); - return `${String(column)} AS ${tableName}_${columnName}`; - } - return column; - }); - } - + buildSelectClause(select?: SelectClause) { + const columns = select ? select : ["*"]; return `SELECT ${columns.join(", ")}`; } - buildFromClause(join?: AnyJoinClauseRecord[]) { - return `FROM ${this._databaseName}.${String(this._tableName)}${join?.length ? ` AS ${this._tableName}` : ""}`; + buildFromClause() { + return `FROM ${this._databaseName}.${String(this._tableName)}`; } buildWhereCondition(column: string, operator: string, value: any): string { @@ -239,26 +171,10 @@ export class QueryBuilder { - const joinType = join.type ? `${join.type} JOIN` : `JOIN`; - const tableName = `${this._databaseName}.${String(join.table)}`; - const as = `AS ${join.as}`; - const on = ["ON", `${String(join.as)}.${String(join.on[0])}`, join.on[1], join.on[2]].join(" "); - return [joinType, tableName, as, on].filter(Boolean).join(" "); - }) - .join(" "); - } - - buildClauses[] | undefined = undefined>( - params?: QueryBuilderParams, - ) { + buildClauses>(params?: TParams) { return { - select: this.buildSelectClause(params?.select, params?.join), - from: this.buildFromClause(params?.join), - join: this.buildJoinClause(params?.join), + select: this.buildSelectClause(params?.select), + from: this.buildFromClause(), where: this.buildWhereClause(params?.where), groupBy: this.buildGroupByClause(params?.groupBy), orderBy: this.buildOrderByClause(params?.orderBy), @@ -267,9 +183,7 @@ export class QueryBuilder[] | undefined = undefined>( - params?: QueryBuilderParams, - ) { + buildQuery>(params?: TParams) { return Object.values(this.buildClauses(params)).join(" ").trim(); } } diff --git a/packages/client/src/table.ts b/packages/client/src/table.ts index 2e357a3..a00d5aa 100644 --- a/packages/client/src/table.ts +++ b/packages/client/src/table.ts @@ -4,13 +4,7 @@ import type { FieldPacket, ResultSetHeader, RowDataPacket } from "mysql2/promise import { Column, type ColumnInfo, type ColumnSchema, type ColumnType } from "./column"; import { Connection } from "./connection"; -import { - type ExtractQuerySelectedColumn, - QueryBuilder, - type WhereClause, - type QueryBuilderParams, - type JoinClauseRecord, -} from "./query/builder"; +import { type ExtractQuerySelectedColumn, QueryBuilder, type WhereClause, type QueryBuilderParams } from "./query/builder"; /** * Interface representing the structure of a table type, including its columns. @@ -341,13 +335,11 @@ export class Table< * * @param {TParams} params - The arguments defining the query, including selected columns, filters, and other options. * - * @returns {Promise<(ExtractQuerySelectedColumn & RowDataPacket)[]>} A promise that resolves to an array of selected rows. + * @returns {Promise<(ExtractQuerySelectedColumn & RowDataPacket)[]>} A promise that resolves to an array of selected rows. */ - async find[] | undefined = undefined>( - params?: QueryBuilderParams, - ) { - type SelectedColumn = ExtractQuerySelectedColumn; - const queryBuilder = new QueryBuilder(this.databaseName, this.name); + async find>(params?: TParams) { + type SelectedColumn = ExtractQuerySelectedColumn; + const queryBuilder = new QueryBuilder(this.databaseName, this.name); const query = queryBuilder.buildQuery(params); const [rows] = await this._connection.client.execute<(SelectedColumn & RowDataPacket)[]>(query); return rows; @@ -363,7 +355,7 @@ export class Table< */ update( values: Partial, - where: WhereClause, + where: WhereClause, ): Promise<[ResultSetHeader, FieldPacket[]]> { const _where = new QueryBuilder(this.databaseName, this.name).buildWhereClause(where); @@ -382,7 +374,7 @@ export class Table< * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the delete operation is complete. */ - delete(where?: WhereClause): Promise<[ResultSetHeader, FieldPacket[]]> { + delete(where?: WhereClause): Promise<[ResultSetHeader, FieldPacket[]]> { if (!where) return this.truncate(); const _where = new QueryBuilder(this.databaseName, this.name).buildWhereClause(where); const query = `DELETE FROM ${this._path} ${_where}`; @@ -408,7 +400,7 @@ export class Table< * @param {TQueryParams} [queryParams] - Optional query builder parameters to refine the search, such as filters, * groupings, orderings, limits, and offsets. * - * @returns {Promise<(ExtractQuerySelectedColumn & { [K in VectorScoreKey]: number } & RowDataPacket)[]>} + * @returns {Promise<(ExtractQuerySelectedColumn & { v_score: number } & RowDataPacket)[]>} * A promise that resolves to an array of rows matching the vector search criteria, each row including * the selected columns and a vector similarity score. */ @@ -418,17 +410,12 @@ export class Table< vectorColumn: TableColumnName; embeddingParams?: TAi extends AnyAI ? Parameters[1] : never; }, - TQueryParams extends QueryBuilderParams, - >( - params: TParams, - queryParams?: TQueryParams, - ): Promise< - (ExtractQuerySelectedColumn & { [K in VectorScoreKey]: number } & RowDataPacket)[] - > { - type SelectedColumn = ExtractQuerySelectedColumn; + TQueryParams extends QueryBuilderParams, + >(params: TParams, queryParams?: TQueryParams) { + type SelectedColumn = ExtractQuerySelectedColumn; type ResultColumn = SelectedColumn & { [K in VectorScoreKey]: number }; - const clauses = new QueryBuilder(this.databaseName, this.name).buildClauses(queryParams); + const clauses = new QueryBuilder(this.databaseName, this.name).buildClauses(queryParams); const promptEmbedding = (await this.ai.embeddings.create(params.prompt, params.embeddingParams))[0] || []; let orderByClause = `ORDER BY ${this.vScoreKey} DESC`; @@ -477,7 +464,7 @@ export class Table< async createChatCompletion< TParams extends Parameters[0] & (TAi extends AnyAI ? Parameters[0] : never) & { template?: string }, - TQueryParams extends QueryBuilderParams, + TQueryParams extends QueryBuilderParams, >(params: TParams, queryParams?: TQueryParams): Promise> { const { prompt, systemRole, template, vectorColumn, embeddingParams, ...createChatCompletionParams } = params; diff --git a/packages/rag/src/chat/index.ts b/packages/rag/src/chat/index.ts index 2771b39..b4032e4 100644 --- a/packages/rag/src/chat/index.ts +++ b/packages/rag/src/chat/index.ts @@ -60,6 +60,9 @@ export class Chat< TDatabase extends AnyDatabase = AnyDatabase, TAi extends AnyAI = AnyAI, TChatCompletionTool extends AnyChatCompletionTool[] | undefined = undefined, + TTableName extends string = string, + TSessionsTableName extends string = string, + TMessagesTableName extends string = string, > { constructor( private _database: TDatabase, @@ -70,9 +73,9 @@ export class Chat< public name: string, public systemRole: ChatSession["systemRole"], public store: ChatSession["store"], - public tableName: string, - public sessionsTableName: ChatSession["tableName"], - public messagesTableName: ChatSession["messagesTableName"], + public tableName: TTableName, + public sessionsTableName: TSessionsTableName, + public messagesTableName: TMessagesTableName, ) {} /** @@ -89,8 +92,8 @@ export class Chat< private static _createTable( database: TDatabase, name: TName, - ): Promise> { - return database.createTable({ + ): Promise> { + return database.createTable({ name, columns: { id: { type: "bigint", autoIncrement: true, primaryKey: true }, @@ -116,13 +119,22 @@ export class Chat< * @param {TAi} ai - The AI instance used in the chat. * @param {TConfig} [config] - The configuration object for the chat. * - * @returns {Promise>} A promise that resolves to the created `Chat` instance. + * @returns {Promise>} A promise that resolves to the created `Chat` instance. */ static async create( database: TDatabase, ai: TAi, config?: TConfig, - ): Promise> { + ): Promise< + Chat< + TDatabase, + TAi, + TConfig["tools"], + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + TConfig["sessionsTableName"] extends string ? TConfig["sessionsTableName"] : string, + TConfig["messagesTableName"] extends string ? TConfig["messagesTableName"] : string + > + > { const createdAt: Chat["createdAt"] = new Date().toISOString().replace("T", " ").substring(0, 23); const _config: ChatConfig = { @@ -157,7 +169,7 @@ export class Chat< id = rows?.[0].insertId; } - return new Chat( + return new Chat( database, ai, _config.tools, @@ -188,7 +200,7 @@ export class Chat< tableName: Chat["tableName"], sessionsTable: Chat["sessionsTableName"], messagesTableName: Chat["messagesTableName"], - where?: Parameters["delete"]>[0], + where?: Parameters["delete"]>[0], ): Promise<[[ResultSetHeader, FieldPacket[]], [ResultSetHeader, FieldPacket[]][]]> { const table = database.table(tableName); const deletedRowIds = await table.find({ select: ["id"], where }); @@ -206,7 +218,7 @@ export class Chat< * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the update operation is complete. */ - async update(values: Parameters["update"]>[0]): Promise<[ResultSetHeader, FieldPacket[]]> { + async update(values: Parameters["update"]>[0]): Promise<[ResultSetHeader, FieldPacket[]]> { const result = await this._database.table(this.tableName).update(values, { id: this.id }); for (const [key, value] of Object.entries(values)) { @@ -234,9 +246,19 @@ export class Chat< * * @param {TName} [name] - The name of the session. * - * @returns {Promise>} A promise that resolves to the created `ChatSession` instance. + * @returns {Promise>} A promise that resolves to the created `ChatSession` instance. */ - createSession(name?: TName): Promise> { + createSession( + name?: TName, + ): Promise< + ChatSession< + TDatabase, + TAi, + TChatCompletionTool, + TSessionsTableName extends string ? TSessionsTableName : string, + TMessagesTableName extends string ? TMessagesTableName : string + > + > { return ChatSession.create(this._database, this._ai, { chatId: this.id, name, @@ -249,18 +271,18 @@ export class Chat< } /** - * Selects chat sessions from the current chat based on the provided filters and options. + * Finds chat sessions from the current chat based on the provided filters and options. * - * @typeParam T - The parameters passed to the `select` method of the `Table` class. + * @typeParam T - The parameters passed to the `find` method of the `Table` class. * - * @param {...T} args - The arguments defining the filters and options for selecting sessions. + * @param {params} params - The parameters defining the filters and options for finding sessions. * - * @returns {Promise[]>} A promise that resolves to an array of `ChatSession` instances representing the selected sessions. + * @returns {Promise[]>} A promise that resolves to an array of `ChatSession` instances representing the found sessions. */ - async selectSessions>["find"]>>( - ...args: T - ): Promise[]> { - const rows = await this._database.table(this.sessionsTableName).find(...args); + async findSessions( + params?: Parameters>["find"]>[0], + ): Promise[]> { + const rows = await this._database.table(this.sessionsTableName).find(params); return rows.map( (row) => diff --git a/packages/rag/src/chat/message.ts b/packages/rag/src/chat/message.ts index a9b554e..5b4f0e0 100644 --- a/packages/rag/src/chat/message.ts +++ b/packages/rag/src/chat/message.ts @@ -37,7 +37,7 @@ export interface ChatMessagesTable { * @property {boolean} store - Whether the message is stored in the database. * @property {string} tableName - The name of the table where the message is stored. */ -export class ChatMessage { +export class ChatMessage { constructor( private _database: TDatabase, public id: number | undefined, @@ -46,7 +46,7 @@ export class ChatMessage { public role: ChatCompletionMessage["role"], public content: ChatCompletionMessage["content"], public store: boolean, - public tableName: string, + public tableName: TTableName, ) {} /** @@ -63,8 +63,8 @@ export class ChatMessage { static createTable( database: TDatabase, name: TName, - ): Promise> { - return database.createTable({ + ): Promise> { + return database.createTable({ name, columns: { id: { type: "bigint", autoIncrement: true, primaryKey: true }, @@ -90,7 +90,7 @@ export class ChatMessage { static async create( database: TDatabase, config: TConfig, - ): Promise> { + ): Promise> { const { sessionId, role, content, store, tableName } = config; const createdAt: ChatMessage["createdAt"] = new Date().toISOString().replace("T", " ").substring(0, 23); let id: ChatMessage["id"]; @@ -115,7 +115,7 @@ export class ChatMessage { static delete( database: AnyDatabase, tableName: ChatMessage["tableName"], - where?: Parameters["delete"]>[0], + where?: Parameters["delete"]>[0], ): Promise<[ResultSetHeader, FieldPacket[]]> { return database.table(tableName).delete(where); } @@ -127,7 +127,9 @@ export class ChatMessage { * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the update operation is complete. */ - async update(values: Parameters["update"]>[0]): Promise<[ResultSetHeader, FieldPacket[]]> { + async update( + values: Parameters["update"]>[0], + ): Promise<[ResultSetHeader, FieldPacket[]]> { const result = await this._database.table(this.tableName).update(values, { id: this.id }); for (const [key, value] of Object.entries(values)) { diff --git a/packages/rag/src/chat/session.ts b/packages/rag/src/chat/session.ts index c4cdc34..5e95a43 100644 --- a/packages/rag/src/chat/session.ts +++ b/packages/rag/src/chat/session.ts @@ -70,6 +70,8 @@ export class ChatSession< TDatabase extends AnyDatabase = AnyDatabase, TAi extends AnyAI = AnyAI, TChatCompletionTool extends AnyChatCompletionTool[] | undefined = undefined, + TTableName extends string = string, + TMessagesTableName extends string = string, > { constructor( private _database: TDatabase, @@ -81,8 +83,8 @@ export class ChatSession< public name: string, public systemRole: string, public store: ChatMessage["store"], - public tableName: string, - public messagesTableName: ChatMessage["tableName"], + public tableName: TTableName, + public messagesTableName: TMessagesTableName, ) {} /** @@ -99,8 +101,8 @@ export class ChatSession< static createTable( database: TDatabase, name: TName, - ): Promise> { - return database.createTable({ + ): Promise> { + return database.createTable({ name, columns: { id: { type: "bigint", autoIncrement: true, primaryKey: true }, @@ -122,13 +124,21 @@ export class ChatSession< * @param {TAi} ai - The AI instance used in the chat session. * @param {TConfig} [config] - The configuration object for the chat session. * - * @returns {Promise>} A promise that resolves to the created `ChatSession` instance. + * @returns {Promise>} A promise that resolves to the created `ChatSession` instance. */ static async create>( database: TDatabase, ai: TAi, config?: TConfig, - ): Promise> { + ): Promise< + ChatSession< + TDatabase, + TAi, + TConfig["tools"], + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + TConfig["messagesTableName"] extends string ? TConfig["messagesTableName"] : string + > + > { const createdAt: ChatSession["createdAt"] = new Date().toISOString().replace("T", " ").substring(0, 23); const _config: ChatSessionConfig = { @@ -150,7 +160,7 @@ export class ChatSession< id = rows?.[0].insertId; } - return new ChatSession( + return new ChatSession( database, ai, _config.tools, @@ -179,7 +189,7 @@ export class ChatSession< database: AnyDatabase, tableName: ChatSession["tableName"], messagesTableName: ChatSession["messagesTableName"], - where?: Parameters["delete"]>[0], + where?: Parameters["delete"]>[0], ): Promise<[ResultSetHeader, FieldPacket[]][]> { const table = database.table(tableName); const deletedRowIds = await table.find({ select: ["id"], where }); @@ -197,7 +207,9 @@ export class ChatSession< * * @returns {Promise<[ResultSetHeader, FieldPacket[]]>} A promise that resolves when the update operation is complete. */ - async update(values: Parameters["update"]>[0]): Promise<[ResultSetHeader, FieldPacket[]]> { + async update( + values: Parameters["update"]>[0], + ): Promise<[ResultSetHeader, FieldPacket[]]> { const result = await this._database.table(this.tableName).update(values, { id: this.id }); for (const [key, value] of Object.entries(values)) { @@ -227,12 +239,12 @@ export class ChatSession< * @param {TRole} role - The role of the message sender. * @param {TContent} content - The content of the chat message. * - * @returns {Promise>} A promise that resolves to the created `ChatMessage` instance. + * @returns {Promise>} A promise that resolves to the created `ChatMessage` instance. */ createMessage( role: TRole, content: TContent, - ): Promise> { + ): Promise> { return ChatMessage.create(this._database, { sessionId: this.id, role, @@ -243,18 +255,16 @@ export class ChatSession< } /** - * Selects chat messages from the current session based on the provided filters and options. - * - * @typeParam T - The parameters passed to the `select` method of the `Table` class. + * Finds chat messages from the current session based on the provided filters and options. * - * @param {...T} args - The arguments defining the filters and options for selecting messages. + * @param {params} params - The parameters defining the filters and options for finding messages. * - * @returns {Promise[]>} A promise that resolves to an array of `ChatMessage` instances representing the selected messages. + * @returns {Promise[]>} A promise that resolves to an array of `ChatMessage` instances representing the found messages. */ - async selectMessages>["find"]>>( - ...args: T - ): Promise[]> { - const rows = await this._database.table(this.messagesTableName).find(...args); + async findMessages( + params?: Parameters>["find"]>[0], + ): Promise[]> { + const rows = await this._database.table(this.messagesTableName).find(params); return rows.map((row) => { return new ChatMessage( @@ -320,7 +330,7 @@ export class ChatSession< if (loadDatabaseSchema || loadHistory) { const [databaseSchema, historyMessages] = await Promise.all([ loadDatabaseSchema ? this._database.describe() : undefined, - loadHistory ? this.selectMessages({ orderBy: { createdAt: "asc" } }) : undefined, + loadHistory ? this.findMessages({ orderBy: { createdAt: "asc" } }) : undefined, ]); if (databaseSchema) { diff --git a/packages/rag/src/index.ts b/packages/rag/src/index.ts index 478e947..2472392 100644 --- a/packages/rag/src/index.ts +++ b/packages/rag/src/index.ts @@ -1,14 +1,5 @@ import type { AnyAI, AnyChatCompletionTool } from "@singlestore/ai"; -import type { - AnyDatabase, - Database, - DatabaseType, - FieldPacket, - InferDatabaseType, - QueryBuilderParams, - ResultSetHeader, - Table, -} from "@singlestore/client"; +import type { AnyDatabase, FieldPacket, InferDatabaseType, ResultSetHeader, Table } from "@singlestore/client"; import { Chat, type CreateChatConfig, type ChatsTable } from "./chat"; @@ -63,31 +54,56 @@ export class RAG>} A promise that resolves to the created `Chat` instance. + * @returns {Promise>} A promise that resolves to the created `Chat` instance. */ - createChat(config?: TConfig): Promise> { + createChat( + config?: TConfig, + ): Promise< + Chat< + TDatabase, + TAi, + TConfig["tools"], + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + TConfig["sessionsTableName"] extends string ? TConfig["sessionsTableName"] : string, + TConfig["messagesTableName"] extends string ? TConfig["messagesTableName"] : string + > + > { return Chat.create(this._database, this._ai, config); } /** - * Selects chat instances from the database based on the provided configuration and arguments. + * Finds chat instances from the database based on the provided configuration and parameters. * - * @typeParam TConfig - The configuration object for selecting chats. - * @typeParam TSelectArgs - The parameters passed to the `select` method of the `Table` class. + * @typeParam TConfig - The configuration object for finding chats. * - * @param {TConfig} [config] - The configuration object for selecting chats. - * @param {...TSelectArgs} selectArgs - The arguments defining the filters and options for selecting chats. + * @param {TConfig} [config] - The configuration object for finding chats. + * @param {findParams} findParams - The parameters defining the filters and options for finding chats. * - * @returns {Promise[]>} A promise that resolves to an array of `Chat` instances representing the selected chats. + * @returns {Promise[]>} A promise that resolves to an array of `Chat` instances representing the found chats. */ async findChats( config?: TConfig, - findParams?: QueryBuilderParams>, - ): Promise[]> { + findParams?: Parameters< + Table< + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + ChatsTable, + InferDatabaseType + >["find"] + >[0], + ): Promise< + Chat< + TDatabase, + TAi, + TConfig["tools"], + TConfig["tableName"] extends string ? TConfig["tableName"] : string, + string, + string + >[] + > { const rows = await this._database.table(config?.tableName || "chats").find(findParams); return rows.map( (row) => - new Chat( + new Chat( this._database, this._ai, config?.tools || [],