diff --git a/package.json b/package.json index db84f87..1cb8bb2 100644 --- a/package.json +++ b/package.json @@ -257,6 +257,11 @@ "type": "object", "description": "If set, migrations will be applied for all analyses. If the current file is a migration file, execution will run until the previous migration." }, + "plpgsqlLanguageServer.plpgsqlCheckSchema": { + "scope": "resource", + "type": "string", + "description": "Schema where plpgsql_check is installed." + }, "plpgsqlLanguageServer.enableExecuteFileQueryCommand": { "scope": "resource", "type": "boolean", diff --git a/sample/definitions/trigger/static_error_disabled.pgsql b/sample/definitions/trigger/static_error_disabled.pgsql new file mode 100644 index 0000000..6f9c1a0 --- /dev/null +++ b/sample/definitions/trigger/static_error_disabled.pgsql @@ -0,0 +1,24 @@ +DROP TABLE IF EXISTS users_2 CASCADE; + +CREATE TABLE users_2 ( + id integer not null PRIMARY KEY +); + +create or replace function update_updated_at_column () + returns trigger + language plpgsql + as $function$ +begin + new.updated_at = NOW(); + return new; +end; +$function$; + +-- plpgsql-language-server:disable-static +create trigger update_users_2_modtime_disabled -- error silenced + before update on users_2 for each row + execute function update_updated_at_column (); + +create trigger update_users_2_modtime -- should raise error + before update on users_2 for each row + execute function update_updated_at_column (); diff --git a/sample/definitions/trigger/static_error_disabled.pgsql.json b/sample/definitions/trigger/static_error_disabled.pgsql.json new file mode 100644 index 0000000..8bc91f8 --- /dev/null +++ b/sample/definitions/trigger/static_error_disabled.pgsql.json @@ -0,0 +1,184 @@ +[ + { + "RawStmt": { + "stmt": { + "DropStmt": { + "objects": [ + { + "List": { + "items": [ + { + "String": { + "str": "users_2" + } + } + ] + } + } + ], + "removeType": "OBJECT_TABLE", + "behavior": "DROP_CASCADE", + "missing_ok": true + } + }, + "stmt_len": 36 + } + }, + { + "RawStmt": { + "stmt": { + "CreateStmt": { + "relation": { + "relname": "users_2", + "inh": true, + "relpersistence": "p" + }, + "tableElts": [ + { + "ColumnDef": { + "colname": "id", + "typeName": { + "names": [ + { + "String": { + "str": "pg_catalog" + } + }, + { + "String": { + "str": "int4" + } + } + ], + "typemod": -1 + }, + "is_local": true, + "constraints": [ + { + "Constraint": { + "contype": "CONSTR_NOTNULL" + } + }, + { + "Constraint": { + "contype": "CONSTR_PRIMARY" + } + } + ] + } + } + ], + "oncommit": "ONCOMMIT_NOOP" + } + }, + "stmt_len": 60 + } + }, + { + "RawStmt": { + "stmt": { + "CreateFunctionStmt": { + "replace": true, + "funcname": [ + { + "String": { + "str": "update_updated_at_column" + } + } + ], + "returnType": { + "names": [ + { + "String": { + "str": "trigger" + } + } + ], + "typemod": -1 + }, + "options": [ + { + "DefElem": { + "defname": "language", + "arg": { + "String": { + "str": "plpgsql" + } + }, + "defaction": "DEFELEM_UNSPEC" + } + }, + { + "DefElem": { + "defname": "as", + "arg": { + "List": { + "items": [ + { + "String": { + "str": "\nbegin\n new.updated_at = NOW();\n return new;\nend;\n" + } + } + ] + } + }, + "defaction": "DEFELEM_UNSPEC" + } + } + ] + } + }, + "stmt_len": 171 + } + }, + { + "RawStmt": { + "stmt": { + "CreateTrigStmt": { + "trigname": "update_users_2_modtime_disabled", + "relation": { + "relname": "users_2", + "inh": true, + "relpersistence": "p" + }, + "funcname": [ + { + "String": { + "str": "update_updated_at_column" + } + } + ], + "row": true, + "timing": 2, + "events": 16 + } + }, + "stmt_len": 195 + } + }, + { + "RawStmt": { + "stmt": { + "CreateTrigStmt": { + "trigname": "update_users_2_modtime", + "relation": { + "relname": "users_2", + "inh": true, + "relpersistence": "p" + }, + "funcname": [ + { + "String": { + "str": "update_updated_at_column" + } + } + ], + "row": true, + "timing": 2, + "events": 16 + } + }, + "stmt_len": 148 + } + } +] \ No newline at end of file diff --git a/sample/definitions/trigger/static_error_trigger_column_does_not_exist.pgsql b/sample/definitions/trigger/static_error_trigger_column_does_not_exist.pgsql new file mode 100644 index 0000000..4552e2f --- /dev/null +++ b/sample/definitions/trigger/static_error_trigger_column_does_not_exist.pgsql @@ -0,0 +1,36 @@ +DROP TABLE IF EXISTS users_1 CASCADE; +DROP TABLE IF EXISTS users_2 CASCADE; +DROP TABLE IF EXISTS users_3 CASCADE; + +CREATE TABLE users_1 ( + id integer not null PRIMARY KEY, + updated_at timestamp with time zone not null DEFAULT now() +); +CREATE TABLE users_2 ( + id integer not null PRIMARY KEY +); +CREATE TABLE users_3 ( + id integer not null PRIMARY KEY +); + +create or replace function update_updated_at_column () + returns trigger + language plpgsql + as $function$ +begin + new.updated_at = NOW(); + return new; +end; +$function$; + +create trigger update_users_3_modtime -- should raise error + before update on users_3 for each row + execute function update_updated_at_column (); + +create trigger update_users_1_modtime + before update on users_1 for each row + execute function update_updated_at_column (); + +create trigger update_users_2_modtime -- should raise error + before update on users_2 for each row + execute function update_updated_at_column (); diff --git a/sample/definitions/trigger/static_error_trigger_column_does_not_exist.pgsql.json b/sample/definitions/trigger/static_error_trigger_column_does_not_exist.pgsql.json new file mode 100644 index 0000000..bda322f --- /dev/null +++ b/sample/definitions/trigger/static_error_trigger_column_does_not_exist.pgsql.json @@ -0,0 +1,403 @@ +[ + { + "RawStmt": { + "stmt": { + "DropStmt": { + "objects": [ + { + "List": { + "items": [ + { + "String": { + "str": "users_1" + } + } + ] + } + } + ], + "removeType": "OBJECT_TABLE", + "behavior": "DROP_CASCADE", + "missing_ok": true + } + }, + "stmt_len": 36 + } + }, + { + "RawStmt": { + "stmt": { + "DropStmt": { + "objects": [ + { + "List": { + "items": [ + { + "String": { + "str": "users_2" + } + } + ] + } + } + ], + "removeType": "OBJECT_TABLE", + "behavior": "DROP_CASCADE", + "missing_ok": true + } + }, + "stmt_len": 37 + } + }, + { + "RawStmt": { + "stmt": { + "DropStmt": { + "objects": [ + { + "List": { + "items": [ + { + "String": { + "str": "users_3" + } + } + ] + } + } + ], + "removeType": "OBJECT_TABLE", + "behavior": "DROP_CASCADE", + "missing_ok": true + } + }, + "stmt_len": 37 + } + }, + { + "RawStmt": { + "stmt": { + "CreateStmt": { + "relation": { + "relname": "users_1", + "inh": true, + "relpersistence": "p" + }, + "tableElts": [ + { + "ColumnDef": { + "colname": "id", + "typeName": { + "names": [ + { + "String": { + "str": "pg_catalog" + } + }, + { + "String": { + "str": "int4" + } + } + ], + "typemod": -1 + }, + "is_local": true, + "constraints": [ + { + "Constraint": { + "contype": "CONSTR_NOTNULL" + } + }, + { + "Constraint": { + "contype": "CONSTR_PRIMARY" + } + } + ] + } + }, + { + "ColumnDef": { + "colname": "updated_at", + "typeName": { + "names": [ + { + "String": { + "str": "pg_catalog" + } + }, + { + "String": { + "str": "timestamptz" + } + } + ], + "typemod": -1 + }, + "is_local": true, + "constraints": [ + { + "Constraint": { + "contype": "CONSTR_NOTNULL" + } + }, + { + "Constraint": { + "contype": "CONSTR_DEFAULT", + "raw_expr": { + "FuncCall": { + "funcname": [ + { + "String": { + "str": "now" + } + } + ] + } + } + } + } + ] + } + } + ], + "oncommit": "ONCOMMIT_NOOP" + } + }, + "stmt_len": 122 + } + }, + { + "RawStmt": { + "stmt": { + "CreateStmt": { + "relation": { + "relname": "users_2", + "inh": true, + "relpersistence": "p" + }, + "tableElts": [ + { + "ColumnDef": { + "colname": "id", + "typeName": { + "names": [ + { + "String": { + "str": "pg_catalog" + } + }, + { + "String": { + "str": "int4" + } + } + ], + "typemod": -1 + }, + "is_local": true, + "constraints": [ + { + "Constraint": { + "contype": "CONSTR_NOTNULL" + } + }, + { + "Constraint": { + "contype": "CONSTR_PRIMARY" + } + } + ] + } + } + ], + "oncommit": "ONCOMMIT_NOOP" + } + }, + "stmt_len": 59 + } + }, + { + "RawStmt": { + "stmt": { + "CreateStmt": { + "relation": { + "relname": "users_3", + "inh": true, + "relpersistence": "p" + }, + "tableElts": [ + { + "ColumnDef": { + "colname": "id", + "typeName": { + "names": [ + { + "String": { + "str": "pg_catalog" + } + }, + { + "String": { + "str": "int4" + } + } + ], + "typemod": -1 + }, + "is_local": true, + "constraints": [ + { + "Constraint": { + "contype": "CONSTR_NOTNULL" + } + }, + { + "Constraint": { + "contype": "CONSTR_PRIMARY" + } + } + ] + } + } + ], + "oncommit": "ONCOMMIT_NOOP" + } + }, + "stmt_len": 59 + } + }, + { + "RawStmt": { + "stmt": { + "CreateFunctionStmt": { + "replace": true, + "funcname": [ + { + "String": { + "str": "update_updated_at_column" + } + } + ], + "returnType": { + "names": [ + { + "String": { + "str": "trigger" + } + } + ], + "typemod": -1 + }, + "options": [ + { + "DefElem": { + "defname": "language", + "arg": { + "String": { + "str": "plpgsql" + } + }, + "defaction": "DEFELEM_UNSPEC" + } + }, + { + "DefElem": { + "defname": "as", + "arg": { + "List": { + "items": [ + { + "String": { + "str": "\nbegin\n new.updated_at = NOW();\n return new;\nend;\n" + } + } + ] + } + }, + "defaction": "DEFELEM_UNSPEC" + } + } + ] + } + }, + "stmt_len": 171 + } + }, + { + "RawStmt": { + "stmt": { + "CreateTrigStmt": { + "trigname": "update_users_3_modtime", + "relation": { + "relname": "users_3", + "inh": true, + "relpersistence": "p" + }, + "funcname": [ + { + "String": { + "str": "update_updated_at_column" + } + } + ], + "row": true, + "timing": 2, + "events": 16 + } + }, + "stmt_len": 148 + } + }, + { + "RawStmt": { + "stmt": { + "CreateTrigStmt": { + "trigname": "update_users_1_modtime", + "relation": { + "relname": "users_1", + "inh": true, + "relpersistence": "p" + }, + "funcname": [ + { + "String": { + "str": "update_updated_at_column" + } + } + ], + "row": true, + "timing": 2, + "events": 16 + } + }, + "stmt_len": 126 + } + }, + { + "RawStmt": { + "stmt": { + "CreateTrigStmt": { + "trigname": "update_users_2_modtime", + "relation": { + "relname": "users_2", + "inh": true, + "relpersistence": "p" + }, + "funcname": [ + { + "String": { + "str": "update_updated_at_column" + } + } + ], + "row": true, + "timing": 2, + "events": 16 + } + }, + "stmt_len": 148 + } + } +] \ No newline at end of file diff --git a/server/src/__tests__/__fixtures__/schemas/correct_schema_1.pgsql b/sample/schemas/correct_schema_1.sql similarity index 100% rename from server/src/__tests__/__fixtures__/schemas/correct_schema_1.pgsql rename to sample/schemas/correct_schema_1.sql diff --git a/server/src/__tests__/__fixtures__/schemas/correct_schema_2.pgsql b/sample/schemas/correct_schema_2.sql similarity index 100% rename from server/src/__tests__/__fixtures__/schemas/correct_schema_2.pgsql rename to sample/schemas/correct_schema_2.sql diff --git a/server/src/__tests__/__fixtures__/schemas/correct_schema_3.pgsql b/sample/schemas/correct_schema_3.sql similarity index 100% rename from server/src/__tests__/__fixtures__/schemas/correct_schema_3.pgsql rename to sample/schemas/correct_schema_3.sql diff --git a/server/src/__tests__/__fixtures__/schemas b/server/src/__tests__/__fixtures__/schemas new file mode 120000 index 0000000..e28eb7d --- /dev/null +++ b/server/src/__tests__/__fixtures__/schemas @@ -0,0 +1 @@ +../../../../sample/schemas/ \ No newline at end of file diff --git a/server/src/__tests__/helpers/file.ts b/server/src/__tests__/helpers/file.ts index 99d4e01..c760985 100644 --- a/server/src/__tests__/helpers/file.ts +++ b/server/src/__tests__/helpers/file.ts @@ -37,7 +37,7 @@ export async function loadSampleFile( } } -function sampleDirPath(): string { +export function sampleDirPath(): string { return path.join(__dirname, "..", "__fixtures__") } diff --git a/server/src/commands/validateWorkspace.ts b/server/src/commands/validateWorkspace.ts index d44b6ca..6d78375 100644 --- a/server/src/commands/validateWorkspace.ts +++ b/server/src/commands/validateWorkspace.ts @@ -64,6 +64,7 @@ export async function validateFile( options.hasDiagnosticRelatedInformationCapability, queryParameterInfo, statements: settings.statements, + plpgsqlCheckSchema: settings.plpgsqlCheckSchema, }, settings, logger, diff --git a/server/src/errors.ts b/server/src/errors.ts index dfce272..662bf26 100644 --- a/server/src/errors.ts +++ b/server/src/errors.ts @@ -71,7 +71,10 @@ export class WorkspaceValidationTargetFilesEmptyError export class MigrationError extends PlpgsqlLanguageServerError { - constructor(public document: TextDocument, message: string) { + public migrationPath: string + + constructor(public document: TextDocument, message: string, migrationPath: string) { super(message) + this.migrationPath = migrationPath } } diff --git a/server/src/postgres/parsers/parseCreateStatements.ts b/server/src/postgres/parsers/parseCreateStatements.ts index 6b85ccd..a198844 100644 --- a/server/src/postgres/parsers/parseCreateStatements.ts +++ b/server/src/postgres/parsers/parseCreateStatements.ts @@ -298,10 +298,10 @@ export function parseIndexCreateStatements( return [] } - const idxname = IndexStmt?.idxname + let idxname = IndexStmt?.idxname if (idxname === undefined) { - throw new ParsedTypeError("IndexStmt.idxname is undefined!") + idxname = "" } const indexNameLocation = findIndexFromBuffer( diff --git a/server/src/postgres/parsers/parseFunctions.ts b/server/src/postgres/parsers/parseFunctions.ts index 77d9231..27cddc0 100644 --- a/server/src/postgres/parsers/parseFunctions.ts +++ b/server/src/postgres/parsers/parseFunctions.ts @@ -12,14 +12,21 @@ export interface FunctionInfo { location: number | undefined, } +export interface TriggerInfo { + functionName: string, + relname: string, + stmtLocation?: number, + stmtLen: number, +} + export async function parseFunctions( uri: URI, queryParameterInfo: QueryParameterInfo | null, logger: Logger, -): Promise { +): Promise<[FunctionInfo[], TriggerInfo[]]> { const fileText = await readFileFromUri(uri) if (fileText === null) { - return [] + return [[], []] } const [sanitizedFileText] = sanitizeFileWithQueryParameters( @@ -28,23 +35,37 @@ export async function parseFunctions( const stmtements = await parseStmtements(uri, sanitizedFileText, logger) if (stmtements === undefined) { - return [] + return [[], []] } - - return stmtements.flatMap( + const functions: FunctionInfo[] = [] + const triggers: TriggerInfo[] = [] + stmtements.forEach( (statement) => { - if (statement?.stmt?.CreateFunctionStmt !== undefined) { + + if (statement?.stmt?.CreateFunctionStmt !== undefined ) { try { - return getCreateFunctions(statement) + functions.push(...getCreateFunctions(statement)) } catch (error: unknown) { logger.error(`ParseFunctionError: ${(error as Error).message} (${uri})`) } } - return [] + if (statement?.stmt?.CreateTrigStmt !== undefined ) { + logger.info(`Statically analyzing trigger: ${JSON.stringify(statement)}`) + try { + triggers.push(...getCreateTriggers(statement)) + } + catch (error: unknown) { + logger.error(`ParseFunctionError: ${(error as Error).message} (${uri})`) + } + } }, ) + + logger.error(JSON.stringify(triggers)) + + return [functions, triggers] } function getCreateFunctions( @@ -83,3 +104,45 @@ function getCreateFunctions( }, ) } + + +function getCreateTriggers( + statement: Statement, +): TriggerInfo[] { + const createTriggerStmt = statement?.stmt?.CreateTrigStmt + if (createTriggerStmt === undefined) { + return [] + } + + const funcname = createTriggerStmt.funcname + if (funcname === undefined) { + throw new ParsedTypeError("createTriggerStmt.funcname is undefined!") + } + let relname = createTriggerStmt.relation?.relname || "" + if (relname === "") { + throw new ParsedTypeError("createTriggerStmt.relation?.relname is undefined!") + } + + const schema = createTriggerStmt.relation?.schemaname + if (schema) { + relname = `${schema}.${relname}` + } + + return funcname.flatMap( + (funcname) => { + const functionName = funcname.String.str + if (functionName === undefined) { + return [] + } + + return [ + { + functionName, + relname, + stmtLocation: statement.stmt_location, + stmtLen: statement.stmt_len, + }, + ] + }, + ) +} diff --git a/server/src/postgres/parsers/statement.ts b/server/src/postgres/parsers/statement.ts index 340b3ef..b0117bd 100644 --- a/server/src/postgres/parsers/statement.ts +++ b/server/src/postgres/parsers/statement.ts @@ -112,6 +112,7 @@ export interface CreateTrigStmt { } export interface CreateTrigStmtRelation { + schemaname?: string relname: string inh: boolean relpersistence: string diff --git a/server/src/postgres/queries/queryFileStaticAnalysis.ts b/server/src/postgres/queries/queryFileStaticAnalysis.ts index 39e8fa1..48cd37c 100644 --- a/server/src/postgres/queries/queryFileStaticAnalysis.ts +++ b/server/src/postgres/queries/queryFileStaticAnalysis.ts @@ -1,13 +1,18 @@ import { Logger, Range, uinteger } from "vscode-languageserver" import { TextDocument } from "vscode-languageserver-textdocument" -import { PostgresPool } from "@/postgres" +import { PostgresClient } from "@/postgres" import { QueryParameterInfo, sanitizeFileWithQueryParameters, } from "@/postgres/parameters" -import { FunctionInfo } from "@/postgres/parsers/parseFunctions" -import { getLineRangeFromBuffer, getTextAllRange } from "@/utilities/text" +import { FunctionInfo, TriggerInfo } from "@/postgres/parsers/parseFunctions" +import { Settings } from "@/settings" +import { DISABLE_STATIC_VALIDATION_RE } from "@/utilities/regex" +import { + getLineRangeFromBuffer, + getRangeFromBuffer, getTextAllRange, +} from "@/utilities/text" export interface StaticAnalysisErrorRow { procedure: string @@ -32,26 +37,25 @@ export interface StaticAnalysisError { export type StaticAnalysisOptions = { isComplete: boolean, queryParameterInfo: QueryParameterInfo | null + plpgsqlCheckSchema?: string + migrations?: Settings["migrations"] } export async function queryFileStaticAnalysis( - pgPool: PostgresPool, + pgClient: PostgresClient, document: TextDocument, functionInfos: FunctionInfo[], + triggerInfos: TriggerInfo[], options: StaticAnalysisOptions, logger: Logger, ): Promise { const errors: StaticAnalysisError[] = [] - const [fileText, parameterNumber] = sanitizeFileWithQueryParameters( + const [fileText] = sanitizeFileWithQueryParameters( document.getText(), options.queryParameterInfo, logger, ) + logger.info(`fileText.length: ${fileText.length}`) - const pgClient = await pgPool.connect() try { - await pgClient.query("BEGIN") - await pgClient.query( - fileText, Array(parameterNumber).fill(null), - ) const extensionCheck = await pgClient.query(` SELECT extname @@ -61,7 +65,10 @@ export async function queryFileStaticAnalysis( extname = 'plpgsql_check' `) + if (extensionCheck.rowCount === 0) { + logger.warn("plpgsql_check is not installed in the database.") + return [] } @@ -91,35 +98,109 @@ export async function queryFileStaticAnalysis( continue } - rows.forEach( - (row) => { - const range = (() => { - return (location === undefined) - ? getTextAllRange(document) - : getLineRangeFromBuffer( - fileText, - location, - row.lineno ? row.lineno - 1 : 0, - ) ?? getTextAllRange(document) - })() - - errors.push({ - level: row.level, range, message: row.message, - }) - }, - ) + extractError(rows, location) } } catch (error: unknown) { + await pgClient.query("ROLLBACK to validated_syntax") + await pgClient.query("BEGIN") if (options.isComplete) { const message = (error as Error).message - logger.error(`StaticAnalysisError: ${message} (${document.uri})`) + logger.error(`StaticAnalysisError (1): ${message} (${document.uri})`) } } - finally { - await pgClient.query("ROLLBACK") - pgClient.release() + + try { + for (const triggerInfo of triggerInfos) { + const { functionName, stmtLocation, relname, stmtLen } = triggerInfo + logger.warn(` + trigger::: + relname: ${relname} + functionName: ${functionName} + stmtLocation: ${stmtLocation}`) + + const result = await pgClient.query( + ` + SELECT + (pcf).functionid::regprocedure AS procedure, + (pcf).lineno AS lineno, + (pcf).statement AS statement, + (pcf).sqlstate AS sqlstate, + (pcf).message AS message, + (pcf).detail AS detail, + (pcf).hint AS hint, + (pcf).level AS level, + (pcf)."position" AS position, + (pcf).query AS query, + (pcf).context AS context + FROM + plpgsql_check_function_tb($1, $2) AS pcf + `, + [functionName, relname], + ) + + const rows: StaticAnalysisErrorRow[] = result.rows + if (rows.length === 0) { + continue + } + + extractError(rows, stmtLocation, stmtLen) + } + } + catch (error: unknown) { + await pgClient.query("ROLLBACK to validated_syntax") + await pgClient.query("BEGIN") + if (options.isComplete) { + const message = (error as Error).message + logger.error(`StaticAnalysisError (2): ${message} (${document.uri})`) + } } return errors + + function extractError( + rows: StaticAnalysisErrorRow[], + location: number | undefined, + stmtLen?: number, + ) { + rows.forEach((row) => { + const range = (() => { + if (location === undefined) { + return getTextAllRange(document) + } + if (stmtLen) { + const stmt = fileText.slice(location + 1, location + 1 + stmtLen) + if (DISABLE_STATIC_VALIDATION_RE + .test(stmt)) { + return + } + + return getRangeFromBuffer( + fileText, + location + 1, + location + 1 + stmtLen, + ) + } + const lineRange = getLineRangeFromBuffer( + fileText, + location, + row.lineno ? row.lineno - 1 : 0, + ) + + if (!lineRange) { + return getTextAllRange(document) + } + + return lineRange + })() + + if (!range) { + return + } + + errors.push({ + level: row.level, range, message: row.message, + }) + }) + } } diff --git a/server/src/postgres/queries/queryFileSyntaxAnalysis.ts b/server/src/postgres/queries/queryFileSyntaxAnalysis.ts index 815f372..ae18ae2 100644 --- a/server/src/postgres/queries/queryFileSyntaxAnalysis.ts +++ b/server/src/postgres/queries/queryFileSyntaxAnalysis.ts @@ -1,30 +1,21 @@ -import fs from "fs/promises" -import glob from "glob-promise" -import path from "path" import { DatabaseError } from "pg" import { Diagnostic, DiagnosticSeverity, Logger, uinteger, } from "vscode-languageserver" import { TextDocument } from "vscode-languageserver-textdocument" -import { MigrationError } from "@/errors" -import { PostgresClient, PostgresPool } from "@/postgres" +import { PostgresClient } from "@/postgres" import { getQueryParameterInfo, QueryParameterInfo, sanitizeFileWithQueryParameters, } from "@/postgres/parameters" -import { MigrationsSettings, Settings, StatementsSettings } from "@/settings" -import { asyncFlatMap } from "@/utilities/functool" +import { Settings, StatementsSettings } from "@/settings" import { neverReach } from "@/utilities/neverReach" +import { + BEGIN_RE, COMMIT_RE, DISABLE_STATEMENT_VALIDATION_RE, ROLLBACK_RE, SQL_COMMENT_RE, +} from "@/utilities/regex" import { getCurrentLineFromIndex, getTextAllRange } from "@/utilities/text" -const SQL_COMMENT_RE = /\/\*[\s\S]*?\*\/|([^:]|^)--.*$/gm -const BEGIN_RE = /^([\s]*begin[\s]*;)/igm -const COMMIT_RE = /^([\s]*commit[\s]*;)/igm -const ROLLBACK_RE = /^([\s]*rollback[\s]*;)/igm - -const DISABLE_STATEMENT_VALIDATION_RE = /^ *-- +plpgsql-language-server:disable *$/m - export type SyntaxAnalysisOptions = { isComplete: boolean queryParameterInfo: QueryParameterInfo | null @@ -32,7 +23,7 @@ export type SyntaxAnalysisOptions = { }; export async function queryFileSyntaxAnalysis( - pgPool: PostgresPool, + pgClient: PostgresClient, document: TextDocument, options: SyntaxAnalysisOptions, settings: Settings, @@ -47,29 +38,6 @@ export async function queryFileSyntaxAnalysis( statementSepRE = new RegExp(`(${options.statements.separatorPattern})`, "g") preparedStatements = documentText.split(statementSepRE) } - const pgClient = await pgPool.connect() - - try { - await pgClient.query("BEGIN") - - if (settings.migrations) { - await runMigration(pgClient, document, settings.migrations, logger) - } - } catch (error: unknown) { - if (error instanceof MigrationError) { - diagnostics.push({ - severity: DiagnosticSeverity.Error, - range: getTextAllRange(document), - message: error.message, - }) - } - - // Restart transaction. - await pgClient.query("ROLLBACK") - await pgClient.query("BEGIN") - } finally { - await pgClient.query("SAVEPOINT migrations") - } const statementNames: string[] = [] for (let i = 0; i < preparedStatements.length; i++) { @@ -100,12 +68,11 @@ export async function queryFileSyntaxAnalysis( currentTextIndex, logger, )) - } finally { - await pgClient.query("ROLLBACK TO migrations") + if (preparedStatements.length > 0) { + await pgClient.query("ROLLBACK TO migrations") + } } } - await pgClient.query("ROLLBACK") - pgClient.release() return diagnostics } @@ -141,104 +108,6 @@ function statementError( } } -async function runMigration( - pgClient: PostgresClient, - document: TextDocument, - migrations: MigrationsSettings, - logger: Logger, -): Promise { - const upMigrationFiles = ( - await asyncFlatMap( - migrations.upFiles, - (filePattern) => glob.promise(filePattern), - )) - .sort((a, b) => a.localeCompare(b, undefined, { numeric: true })) - - const downMigrationFiles = ( - await asyncFlatMap( - migrations.downFiles, - (filePattern) => glob.promise(filePattern), - )) - .sort((a, b) => b.localeCompare(a, undefined, { numeric: true })) - - const postMigrationFiles = ( - await asyncFlatMap( - migrations.postMigrationFiles ?? [], - (filePattern) => glob.promise(filePattern), - )) - .sort((a, b) => a.localeCompare(b, undefined, { numeric: true })) - - const migrationTarget = migrations?.target ?? "up/down" - - if (migrationTarget === "up/down" - && ( - // Check if it is not a migration file. - upMigrationFiles.filter(file => document.uri.endsWith(file)).length - + downMigrationFiles.filter(file => document.uri.endsWith(file)).length === 0 - ) - ) { - return - } - - let shouldContinue = true - - if (shouldContinue) { - shouldContinue = await queryMigrations( - pgClient, document, downMigrationFiles, logger, - ) - } - - if (shouldContinue) { - shouldContinue = await queryMigrations( - pgClient, document, upMigrationFiles, logger, - ) - } - - if (shouldContinue) { - shouldContinue = await queryMigrations( - pgClient, document, postMigrationFiles, logger, - ) - } -} - -async function queryMigrations( - pgClient: PostgresClient, - document: TextDocument, - files: string[], - logger: Logger, -): Promise { - for await (const file of files) { - try { - if (document.uri.endsWith(path.normalize(file))) { - // allow us to revisit and work on any migration/post-migration file - logger.info("Stopping migration execution at the current file") - - return false - } - - logger.info(`Migration ${file}`) - - const migration = (await fs.readFile(file, { encoding: "utf8" })) - .replace(BEGIN_RE, (m) => "-".repeat(m.length)) - .replace(COMMIT_RE, (m) => "-".repeat(m.length)) - .replace(ROLLBACK_RE, (m) => "-".repeat(m.length)) - - await pgClient.query(migration) - } catch (error: unknown) { - const databaseErrorMessage = (error as DatabaseError).message - const filename = path.basename(file) - const errorMessage = - `Stopping migration execution at ${filename}: ${databaseErrorMessage}` - - logger.error(errorMessage) - - throw new MigrationError(document, errorMessage) - } - } - - return true -} - function queryStatement( document: TextDocument, statement: string, diff --git a/server/src/services/migrations.ts b/server/src/services/migrations.ts new file mode 100644 index 0000000..db5dc7a --- /dev/null +++ b/server/src/services/migrations.ts @@ -0,0 +1,114 @@ +import fs from "fs/promises" +import glob from "glob-promise" +import path from "path" +import { DatabaseError } from "pg" +import { + Logger, +} from "vscode-languageserver" +import { TextDocument } from "vscode-languageserver-textdocument" + +import { MigrationError } from "@/errors" +import { PostgresClient } from "@/postgres" +import { MigrationsSettings } from "@/settings" +import { asyncFlatMap } from "@/utilities/functool" +import { BEGIN_RE, COMMIT_RE, ROLLBACK_RE } from "@/utilities/regex" + + +export async function runMigration( + pgClient: PostgresClient, + document: TextDocument, + migrations: MigrationsSettings, + logger: Logger, +): Promise { + const upMigrationFiles = ( + await asyncFlatMap( + migrations.upFiles, + (filePattern) => glob.promise(filePattern), + )) + .sort((a, b) => a.localeCompare(b, undefined, { numeric: true })) + .map(file => path.normalize(file)) + + const downMigrationFiles = ( + await asyncFlatMap( + migrations.downFiles, + (filePattern) => glob.promise(filePattern), + )) + .sort((a, b) => b.localeCompare(a, undefined, { numeric: true })) + .map(file => path.normalize(file)) + + const postMigrationFiles = ( + await asyncFlatMap( + migrations.postMigrationFiles ?? [], + (filePattern) => glob.promise(filePattern), + )) + .sort((a, b) => a.localeCompare(b, undefined, { numeric: true })) + .map(file => path.normalize(file)) + + const migrationTarget = migrations?.target ?? "up/down" + const currentFileIsMigration = + upMigrationFiles.filter(file => document.uri.endsWith(file)).length + + downMigrationFiles.filter(file => document.uri.endsWith(file)).length !== 0 + + if (migrationTarget === "up/down" && !currentFileIsMigration) { + return currentFileIsMigration + } + + let shouldContinue = true + + if (shouldContinue) { + shouldContinue = await queryMigrations( + pgClient, document, downMigrationFiles, logger, + ) + } + + if (shouldContinue) { + shouldContinue = await queryMigrations( + pgClient, document, upMigrationFiles, logger, + ) + } + + if (shouldContinue) { + shouldContinue = await queryMigrations( + pgClient, document, postMigrationFiles, logger, + ) + } + + return currentFileIsMigration +} + +async function queryMigrations( + pgClient: PostgresClient, + document: TextDocument, + files: string[], + logger: Logger, +): Promise { + for await (const file of files) { + try { + if (document.uri.endsWith(file)) { + // allow us to revisit and work on any migration/post-migration file + logger.info("Stopping migration execution at the current file") + + return false + } + + logger.info(`Migration ${file}`) + + const migration = (await fs.readFile(file, { encoding: "utf8" })) + .replace(BEGIN_RE, (m) => "-".repeat(m.length)) + .replace(COMMIT_RE, (m) => "-".repeat(m.length)) + .replace(ROLLBACK_RE, (m) => "-".repeat(m.length)) + + await pgClient.query(migration) + } catch (error: unknown) { + const errorMessage = (error as DatabaseError).message + + logger.error( + `Stopping migration execution at ${file}: ${errorMessage}`, + ) + + throw new MigrationError(document, errorMessage, file) + } + } + + return true +} diff --git a/server/src/services/validation.test.ts b/server/src/services/validation.test.ts index db3cc09..07a10fa 100644 --- a/server/src/services/validation.test.ts +++ b/server/src/services/validation.test.ts @@ -1,6 +1,10 @@ +import glob from "glob-promise" +import path from "path" import { Diagnostic, DiagnosticSeverity, Range } from "vscode-languageserver" -import { DEFAULT_LOAD_FILE_OPTIONS, LoadFileOptions } from "@/__tests__/helpers/file" +import { + DEFAULT_LOAD_FILE_OPTIONS, LoadFileOptions, sampleDirPath, +} from "@/__tests__/helpers/file" import { RecordLogger } from "@/__tests__/helpers/logger" import { setupTestServer } from "@/__tests__/helpers/server" import { SettingsBuilder } from "@/__tests__/helpers/settings" @@ -68,6 +72,41 @@ describe("Validate Tests", () => { ]) }) + it("TRIGGER on inexistent field", async () => { + const diagnostics = await validateSampleFile( + "definitions/trigger/static_error_trigger_column_does_not_exist.pgsql", + ) + + expect(diagnostics).toStrictEqual([ + { + severity: DiagnosticSeverity.Error, + message: 'record "new" has no field "updated_at"', + range: Range.create(24, 0, 27, 47), + }, { + severity: DiagnosticSeverity.Error, + message: 'record "new" has no field "updated_at"', + range: Range.create(32, 0, 35, 47), + }, + ]) + }) + + it("static analysis disabled on invalid statement", async () => { + const diagnostics = await validateSampleFile( + "definitions/trigger/static_error_disabled.pgsql", + ) + + if (!diagnostics) { + throw new Error("") + } + if (diagnostics?.length === 0) { + throw new Error("") + } + + expect(diagnostics).toHaveLength(1) + expect(diagnostics[0].message) + .toContain("record \"new\" has no field \"updated_at\"") + }) + it("FUNCTION column does not exists", async () => { const diagnostics = await validateSampleFile( "definitions/function/syntax_error_function_column_does_not_exist.pgsql", @@ -91,6 +130,18 @@ describe("Validate Tests", () => { expect(diagnostics).toStrictEqual([]) }) + + it("correct schemas", async () => { + const schemas = (await glob.promise(path.join(sampleDirPath(), "schemas/*.sql"))) + .map(file => path.relative(sampleDirPath(), file)) + + schemas.forEach(async (schema) => { + const diagnostics = await validateSampleFile(schema) + + expect(diagnostics).toStrictEqual([]) + }) + }) + it("Syntax error query", async () => { const diagnostics = await validateSampleFile( "queries/syntax_error_query_with_language_server_disable_comment.pgsql", diff --git a/server/src/services/validation.ts b/server/src/services/validation.ts index 2fd40c0..eb34174 100644 --- a/server/src/services/validation.ts +++ b/server/src/services/validation.ts @@ -1,18 +1,24 @@ +import path from "path" import { Diagnostic, DiagnosticSeverity, Logger } from "vscode-languageserver" import { TextDocument } from "vscode-languageserver-textdocument" -import { PostgresPool } from "@/postgres" +import { MigrationError } from "@/errors" +import { PostgresClient, PostgresPool } from "@/postgres" import { QueryParameterInfo } from "@/postgres/parameters" import { parseFunctions } from "@/postgres/parsers/parseFunctions" import { queryFileStaticAnalysis } from "@/postgres/queries/queryFileStaticAnalysis" import { queryFileSyntaxAnalysis } from "@/postgres/queries/queryFileSyntaxAnalysis" +import { runMigration } from "@/services/migrations" import { Settings, StatementsSettings } from "@/settings" +import { getTextAllRange } from "@/utilities/text" type ValidateTextDocumentOptions = { isComplete: boolean, hasDiagnosticRelatedInformationCapability: boolean, queryParameterInfo: QueryParameterInfo | null, statements?: StatementsSettings, + plpgsqlCheckSchema?: string, + migrations?: Settings["migrations"] } export async function validateTextDocument( @@ -22,29 +28,93 @@ export async function validateTextDocument( settings: Settings, logger: Logger, ): Promise { - let diagnostics: Diagnostic[] = [] - diagnostics = await validateSyntaxAnalysis( - pgPool, + const diagnostics: Diagnostic[] = [] + + const pgClient = await pgPool.connect() + + await setupEnvironment(pgClient, options) + + try { + await pgClient.query("BEGIN") + + if (settings.migrations) { + await runMigration( + pgClient, + document, + settings.migrations, + logger, + ) + } + } catch (error: unknown) { + if (error instanceof MigrationError) { + diagnostics.push({ + severity: DiagnosticSeverity.Error, + range: getTextAllRange(document), + message: `${error.migrationPath}: ${error.message}`, + relatedInformation: [ + { + location: { + uri: path.resolve(error.migrationPath), + range: getTextAllRange(document), + }, + message: error.message, + }, + ], + }) + } + + // Restart transaction. + await pgClient.query("ROLLBACK") + await pgClient.query("BEGIN") + } finally { + await pgClient.query("SAVEPOINT migrations") + } + + const syntaxDiagnostics = await validateSyntaxAnalysis( + pgClient, document, options, settings, logger, ) + diagnostics.push(...syntaxDiagnostics) - // TODO static analysis for statements - // if (diagnostics.filter(d => d.severity === DiagnosticSeverity.Error).length === 0) { - if (diagnostics.length === 0) { - diagnostics = await validateStaticAnalysis( - pgPool, + if (diagnostics.filter(d => d.severity === DiagnosticSeverity.Error).length === 0) { + await pgClient.query("SAVEPOINT validated_syntax") + const staticDiagnostics = await validateStaticAnalysis( + pgClient, document, options, logger, ) + diagnostics.push(...staticDiagnostics) } + await pgClient.query("ROLLBACK") + pgClient.release() + return diagnostics } +async function setupEnvironment( + pgClient: PostgresClient, + options: ValidateTextDocumentOptions, +) { + const plpgsqlCheckSchema = options.plpgsqlCheckSchema + // outside transaction + if (plpgsqlCheckSchema) { + await pgClient.query(` + SELECT + set_config( + 'search_path', + current_setting('search_path') || ',${plpgsqlCheckSchema}', + false + ) + WHERE current_setting('search_path') !~ '(^|,)${plpgsqlCheckSchema}(,|$)' + `) + } +} + export async function isCorrectFileValidation( pgPool: PostgresPool, document: TextDocument, @@ -70,14 +140,14 @@ export async function isCorrectFileValidation( } async function validateSyntaxAnalysis( - pgPool: PostgresPool, + pgClient: PostgresClient, document: TextDocument, options: ValidateTextDocumentOptions, settings: Settings, logger: Logger, ): Promise { return await queryFileSyntaxAnalysis( - pgPool, + pgClient, document, options, settings, @@ -86,18 +156,26 @@ async function validateSyntaxAnalysis( } async function validateStaticAnalysis( - pgPool: PostgresPool, + pgClient: PostgresClient, document: TextDocument, options: ValidateTextDocumentOptions, logger: Logger, ): Promise { + const [functionInfos, triggerInfos] = await parseFunctions( + document.uri, + options.queryParameterInfo, + logger, + ) const errors = await queryFileStaticAnalysis( - pgPool, + pgClient, document, - await parseFunctions(document.uri, options.queryParameterInfo, logger), + functionInfos, + triggerInfos, { isComplete: options.isComplete, queryParameterInfo: options.queryParameterInfo, + plpgsqlCheckSchema: options.plpgsqlCheckSchema, + migrations: options.migrations, }, logger, ) diff --git a/server/src/settings.ts b/server/src/settings.ts index b6f798e..6e0917d 100644 --- a/server/src/settings.ts +++ b/server/src/settings.ts @@ -6,6 +6,7 @@ export interface Settings { password?: string; definitionFiles: string[]; defaultSchema: string; + plpgsqlCheckSchema?: string; queryParameterPattern: string | string[]; keywordQueryParameterPattern?: string | string[]; enableExecuteFileQueryCommand: boolean; @@ -41,6 +42,7 @@ export const DEFAULT_SETTINGS: Settings = { password: undefined, definitionFiles: ["**/*.psql", "**/*.pgsql"], defaultSchema: "public", + plpgsqlCheckSchema: undefined, queryParameterPattern: /\$[1-9][0-9]*/.source, keywordQueryParameterPattern: undefined, enableExecuteFileQueryCommand: true, diff --git a/server/src/utilities/regex.ts b/server/src/utilities/regex.ts index 8630ae9..1894365 100644 --- a/server/src/utilities/regex.ts +++ b/server/src/utilities/regex.ts @@ -1,3 +1,14 @@ +/* eslint-disable max-len */ + export function escapeRegex(string: string): string { return string.replace(/[-/\\^$*+?.()|[\]{}]/g, "\\$&") } + + +export const SQL_COMMENT_RE = /\/\*[\s\S]*?\*\/|([^:]|^)--.*$/gm +export const BEGIN_RE = /^([\s]*begin[\s]*;)/igm +export const COMMIT_RE = /^([\s]*commit[\s]*;)/igm +export const ROLLBACK_RE = /^([\s]*rollback[\s]*;)/igm + +export const DISABLE_STATEMENT_VALIDATION_RE = /^ *-- +plpgsql-language-server:disable *$/m +export const DISABLE_STATIC_VALIDATION_RE = /^ *-- +plpgsql-language-server:disable-static *$/m diff --git a/server/src/utilities/text.ts b/server/src/utilities/text.ts index ce2880a..28802ae 100644 --- a/server/src/utilities/text.ts +++ b/server/src/utilities/text.ts @@ -88,7 +88,7 @@ export function getLineRangeFromBuffer( fileText: string, index: uinteger, offsetLine: uinteger = 0, ): Range | undefined { const textLines = Buffer.from(fileText) - .slice(0, index) + .subarray(0, index) .toString() .split("\n")