Skip to content

Commit

Permalink
add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
ddh-5230 committed Sep 13, 2024
1 parent c739821 commit de854ef
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 15 deletions.
34 changes: 21 additions & 13 deletions catalog/identifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package catalog

import (
"strings"

"github.com/apecloud/myduckserver/transpiler"

"github.com/dolthub/vitess/go/vt/sqlparser"
)

func FullSchemaName(catalog, schema string) string {
Expand Down Expand Up @@ -41,21 +45,25 @@ func DecodeIndexName(encodedName string) (string, string) {
return parts[0], parts[1]
}

// DecodeCreateindex extracts column names from a SQL string, Only consider single-column indexes or multi-column indexes
// TODO: using sqlparser to parse columns name, now identifiers(index name, table name, column name) cannot include parentheses.
// such as CREATE INDEX "idx((())hello" ON db.T((t.a)); will cause an error
func DecodeCreateindex(createIndexSQL string) []string {
leftParen := strings.Index(createIndexSQL, "(")
rightParen := strings.Index(createIndexSQL, ")")
if leftParen != -1 && rightParen != -1 {
content := createIndexSQL[leftParen+1 : rightParen]
columns := strings.Split(content, ",")
for i, col := range columns {
columns[i] = strings.TrimSpace(col)
func DecodeCreateindex(createIndexSQL string) ([]string, error) {

denormalizedQuery := transpiler.DenormalizeStrings(createIndexSQL)
stmt, err := sqlparser.Parse(denormalizedQuery)

if err != nil {
return nil, err
}

switch stmt := stmt.(type) {
case *sqlparser.AlterTable:
var columnNames []string
for _, column := range stmt.Statements[0].IndexSpec.Columns {
columnNames = append(columnNames, column.Column.String())
}
return columns
return columnNames, nil
}
return []string{}

return nil, nil
}

func QuoteIdentifierANSI(identifier string) string {
Expand Down
8 changes: 6 additions & 2 deletions catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,11 @@ func (t *Table) CreateIndex(ctx *sql.Context, indexDef sql.IndexDef) error {

// Construct the SQL statement for creating the index
var sqlsBuilder strings.Builder
sqlsBuilder.WriteString(fmt.Sprintf(`USE %s; `, t.db.catalog))
sqlsBuilder.WriteString(fmt.Sprintf(`CREATE %s INDEX "%s" ON %s (%s)`,
unique,
EncodeIndexName(t.name, indexDef.Name),
FullTableName(t.db.catalog, t.db.name, t.name),
FullTableName("", t.db.name, t.name),
strings.Join(columns, ", ")))

// Add the index comment if provided
Expand Down Expand Up @@ -439,7 +440,10 @@ func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
}

_, indexName := DecodeIndexName(encodedIndexName)
columnNames := DecodeCreateindex(createIndexSQL)
columnNames, err := DecodeCreateindex(createIndexSQL)
if err != nil {
return nil, ErrDuckDB.New(err)
}

for _, columnName := range columnNames {
if columnInfo, exists := columnsInfoMap[columnName]; exists {
Expand Down
96 changes: 96 additions & 0 deletions transpiler/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,99 @@ func NormalizeStrings(q string) string {

return normalized.String()
}

// normalizes a query string to convert any Postgres syntax to MySQL syntax
func DenormalizeStrings(q string) string {
state := notInString
lastCharWasBackslash := false
denormalized := strings.Builder{}

for _, c := range q {
switch state {
case notInString:
switch c {
case singleQuote:
state = inSingleQuote
denormalized.WriteRune(singleQuote)
case doubleQuote:
state = inDoubleQuote
denormalized.WriteRune(backtick)
default:
denormalized.WriteRune(c)
}
case inSingleQuote:
switch c {
case backslash:
if lastCharWasBackslash {
denormalized.WriteRune(c)
} else {
lastCharWasBackslash = !lastCharWasBackslash
}
case singleQuote:
if lastCharWasBackslash {
denormalized.WriteRune(c)
lastCharWasBackslash = false
} else {
state = maybeEndSingleQuote
}
default:
lastCharWasBackslash = false
denormalized.WriteRune(c)
}
case maybeEndSingleQuote:
switch c {
case singleQuote:
state = inSingleQuote
denormalized.WriteRune(singleQuote)
denormalized.WriteRune(singleQuote)
default:
state = notInString
denormalized.WriteRune(singleQuote)
denormalized.WriteRune(c)
}
case inDoubleQuote:
switch c {
case backslash:
if lastCharWasBackslash {
denormalized.WriteRune(c)
} else {
lastCharWasBackslash = !lastCharWasBackslash
}
case doubleQuote:
if lastCharWasBackslash {
denormalized.WriteRune(c)
lastCharWasBackslash = !lastCharWasBackslash
} else {
state = maybeEndDoubleQuote
}
case backtick:
denormalized.WriteRune(backtick)
denormalized.WriteRune(backtick)
default:
lastCharWasBackslash = false
denormalized.WriteRune(c)
}
case maybeEndDoubleQuote:
switch c {
case doubleQuote:
state = inDoubleQuote
denormalized.WriteRune(doubleQuote)
default:
state = notInString
denormalized.WriteRune(backtick)
denormalized.WriteRune(c)
}
default:
panic("unknown state")
}
}
switch state {
case maybeEndSingleQuote:
denormalized.WriteRune(singleQuote)
case maybeEndDoubleQuote:
denormalized.WriteRune(backtick)
default: // do nothing

}
return denormalized.String()
}
69 changes: 69 additions & 0 deletions transpiler/converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,72 @@ func TestNormalizeStrings(t *testing.T) {
})
}
}

// Test converting Postgres strings to MySQL strings
func TestDeNormalizeStrings(t *testing.T) {
type test struct {
input string
expected string
}
tests := []test{
{
input: `SELECT 'foo' FROM "bar"`,
expected: "SELECT 'foo' FROM `bar`",
},
{
input: `SELECT 'foo'`,
expected: `SELECT 'foo'`,
},
{
input: `SELECT 'fo"o'`,
expected: `SELECT 'fo"o'`,
},
{
input: `SELECT 'fo''o'`,
expected: `SELECT 'fo''o'`,
},
{
input: `SELECT 'fo"o'`,
expected: `SELECT 'fo"o'`,
},
{
input: `SELECT 'fo''o'`,
expected: "SELECT 'fo''o'",
},
{
input: `SELECT 'fo''''o'`,
expected: "SELECT 'fo''''o'",
},
{
input: `SELECT "foo" FROM "bar"`,
expected: "SELECT `foo` FROM `bar`",
},
{
input: `SELECT 'foo' FROM "bar"`,
expected: "SELECT 'foo' FROM `bar`",
},
{
input: `SELECT 'foo' from "bar" where "bar"."baz" = 'qux'`,
expected: "SELECT 'foo' from `bar` where `bar`.`baz` = 'qux'",
},
{
input: `SELECT "fo""o" FROM "bar"`,
expected: "SELECT `fo\"o` FROM `bar`",
},
{
input: "SELECT \"fo`o\" FROM \"bar\"",
expected: "SELECT `fo``o` FROM `bar`",
},
{
input: `SELECT 'fo""o' FROM "bar"`,
expected: "SELECT 'fo\"\"o' FROM `bar`",
},
}

for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
actual := DenormalizeStrings(test.input)
require.Equal(t, test.expected, actual)
})
}
}

0 comments on commit de854ef

Please sign in to comment.