From de854ef4051f368a3489635aa315050608bc5154 Mon Sep 17 00:00:00 2001 From: ddh-5230 <2023478471@qq.com> Date: Fri, 13 Sep 2024 20:17:02 +0800 Subject: [PATCH] add test case --- catalog/identifier.go | 34 ++++++++----- catalog/table.go | 8 ++- transpiler/converter.go | 96 ++++++++++++++++++++++++++++++++++++ transpiler/converter_test.go | 69 ++++++++++++++++++++++++++ 4 files changed, 192 insertions(+), 15 deletions(-) diff --git a/catalog/identifier.go b/catalog/identifier.go index f7a428d6..4a8a0da2 100644 --- a/catalog/identifier.go +++ b/catalog/identifier.go @@ -2,6 +2,10 @@ package catalog import ( "strings" + + "github.com/apecloud/myduckserver/transpiler" + + "github.com/dolthub/vitess/go/vt/sqlparser" ) func FullSchemaName(catalog, schema string) string { @@ -41,21 +45,25 @@ func DecodeIndexName(encodedName string) (string, string) { return parts[0], parts[1] } -// DecodeCreateindex extracts column names from a SQL string, Only consider single-column indexes or multi-column indexes -// TODO: using sqlparser to parse columns name, now identifiers(index name, table name, column name) cannot include parentheses. -// such as CREATE INDEX "idx((())hello" ON db.T((t.a)); will cause an error -func DecodeCreateindex(createIndexSQL string) []string { - leftParen := strings.Index(createIndexSQL, "(") - rightParen := strings.Index(createIndexSQL, ")") - if leftParen != -1 && rightParen != -1 { - content := createIndexSQL[leftParen+1 : rightParen] - columns := strings.Split(content, ",") - for i, col := range columns { - columns[i] = strings.TrimSpace(col) +func DecodeCreateindex(createIndexSQL string) ([]string, error) { + + denormalizedQuery := transpiler.DenormalizeStrings(createIndexSQL) + stmt, err := sqlparser.Parse(denormalizedQuery) + + if err != nil { + return nil, err + } + + switch stmt := stmt.(type) { + case *sqlparser.AlterTable: + var columnNames []string + for _, column := range stmt.Statements[0].IndexSpec.Columns { + columnNames = append(columnNames, column.Column.String()) } - return columns + return columnNames, nil } - return []string{} + + return nil, nil } func QuoteIdentifierANSI(identifier string) string { diff --git a/catalog/table.go b/catalog/table.go index dc4e2ecd..2a456867 100644 --- a/catalog/table.go +++ b/catalog/table.go @@ -337,10 +337,11 @@ func (t *Table) CreateIndex(ctx *sql.Context, indexDef sql.IndexDef) error { // Construct the SQL statement for creating the index var sqlsBuilder strings.Builder + sqlsBuilder.WriteString(fmt.Sprintf(`USE %s; `, t.db.catalog)) sqlsBuilder.WriteString(fmt.Sprintf(`CREATE %s INDEX "%s" ON %s (%s)`, unique, EncodeIndexName(t.name, indexDef.Name), - FullTableName(t.db.catalog, t.db.name, t.name), + FullTableName("", t.db.name, t.name), strings.Join(columns, ", "))) // Add the index comment if provided @@ -439,7 +440,10 @@ func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { } _, indexName := DecodeIndexName(encodedIndexName) - columnNames := DecodeCreateindex(createIndexSQL) + columnNames, err := DecodeCreateindex(createIndexSQL) + if err != nil { + return nil, ErrDuckDB.New(err) + } for _, columnName := range columnNames { if columnInfo, exists := columnsInfoMap[columnName]; exists { diff --git a/transpiler/converter.go b/transpiler/converter.go index cb35758a..1cfc8b90 100644 --- a/transpiler/converter.go +++ b/transpiler/converter.go @@ -167,3 +167,99 @@ func NormalizeStrings(q string) string { return normalized.String() } + +// normalizes a query string to convert any Postgres syntax to MySQL syntax +func DenormalizeStrings(q string) string { + state := notInString + lastCharWasBackslash := false + denormalized := strings.Builder{} + + for _, c := range q { + switch state { + case notInString: + switch c { + case singleQuote: + state = inSingleQuote + denormalized.WriteRune(singleQuote) + case doubleQuote: + state = inDoubleQuote + denormalized.WriteRune(backtick) + default: + denormalized.WriteRune(c) + } + case inSingleQuote: + switch c { + case backslash: + if lastCharWasBackslash { + denormalized.WriteRune(c) + } else { + lastCharWasBackslash = !lastCharWasBackslash + } + case singleQuote: + if lastCharWasBackslash { + denormalized.WriteRune(c) + lastCharWasBackslash = false + } else { + state = maybeEndSingleQuote + } + default: + lastCharWasBackslash = false + denormalized.WriteRune(c) + } + case maybeEndSingleQuote: + switch c { + case singleQuote: + state = inSingleQuote + denormalized.WriteRune(singleQuote) + denormalized.WriteRune(singleQuote) + default: + state = notInString + denormalized.WriteRune(singleQuote) + denormalized.WriteRune(c) + } + case inDoubleQuote: + switch c { + case backslash: + if lastCharWasBackslash { + denormalized.WriteRune(c) + } else { + lastCharWasBackslash = !lastCharWasBackslash + } + case doubleQuote: + if lastCharWasBackslash { + denormalized.WriteRune(c) + lastCharWasBackslash = !lastCharWasBackslash + } else { + state = maybeEndDoubleQuote + } + case backtick: + denormalized.WriteRune(backtick) + denormalized.WriteRune(backtick) + default: + lastCharWasBackslash = false + denormalized.WriteRune(c) + } + case maybeEndDoubleQuote: + switch c { + case doubleQuote: + state = inDoubleQuote + denormalized.WriteRune(doubleQuote) + default: + state = notInString + denormalized.WriteRune(backtick) + denormalized.WriteRune(c) + } + default: + panic("unknown state") + } + } + switch state { + case maybeEndSingleQuote: + denormalized.WriteRune(singleQuote) + case maybeEndDoubleQuote: + denormalized.WriteRune(backtick) + default: // do nothing + + } + return denormalized.String() +} diff --git a/transpiler/converter_test.go b/transpiler/converter_test.go index dba0621a..7c303d58 100644 --- a/transpiler/converter_test.go +++ b/transpiler/converter_test.go @@ -106,3 +106,72 @@ func TestNormalizeStrings(t *testing.T) { }) } } + +// Test converting Postgres strings to MySQL strings +func TestDeNormalizeStrings(t *testing.T) { + type test struct { + input string + expected string + } + tests := []test{ + { + input: `SELECT 'foo' FROM "bar"`, + expected: "SELECT 'foo' FROM `bar`", + }, + { + input: `SELECT 'foo'`, + expected: `SELECT 'foo'`, + }, + { + input: `SELECT 'fo"o'`, + expected: `SELECT 'fo"o'`, + }, + { + input: `SELECT 'fo''o'`, + expected: `SELECT 'fo''o'`, + }, + { + input: `SELECT 'fo"o'`, + expected: `SELECT 'fo"o'`, + }, + { + input: `SELECT 'fo''o'`, + expected: "SELECT 'fo''o'", + }, + { + input: `SELECT 'fo''''o'`, + expected: "SELECT 'fo''''o'", + }, + { + input: `SELECT "foo" FROM "bar"`, + expected: "SELECT `foo` FROM `bar`", + }, + { + input: `SELECT 'foo' FROM "bar"`, + expected: "SELECT 'foo' FROM `bar`", + }, + { + input: `SELECT 'foo' from "bar" where "bar"."baz" = 'qux'`, + expected: "SELECT 'foo' from `bar` where `bar`.`baz` = 'qux'", + }, + { + input: `SELECT "fo""o" FROM "bar"`, + expected: "SELECT `fo\"o` FROM `bar`", + }, + { + input: "SELECT \"fo`o\" FROM \"bar\"", + expected: "SELECT `fo``o` FROM `bar`", + }, + { + input: `SELECT 'fo""o' FROM "bar"`, + expected: "SELECT 'fo\"\"o' FROM `bar`", + }, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + actual := DenormalizeStrings(test.input) + require.Equal(t, test.expected, actual) + }) + } +}