Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions ext/store/maxcompute/sanitizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package maxcompute

import "github.com/aliyun/aliyun-odps-go-sdk/odps/common"

var (
// reserved keywords https://www.alibabacloud.com/help/en/maxcompute/user-guide/reserved-words-and-keywords
reservedKeywords = []string{
"add", "after", "all", "alter", "analyze", "and", "archive", "array", "as", "asc",
"before", "between", "bigint", "binary", "blob", "boolean", "both", "decimal",
"bucket", "buckets", "by", "cascade", "case", "cast", "cfile", "change", "cluster",
"clustered", "clusterstatus", "collection", "column", "columns", "comment", "compute",
"concatenate", "continue", "create", "cross", "current", "cursor", "data", "database",
"databases", "date", "datetime", "dbproperties", "deferred", "delete", "delimited",
"desc", "describe", "directory", "disable", "distinct", "distribute", "double", "drop",
"else", "enable", "end", "except", "escaped", "exclusive", "exists", "explain", "export",
"extended", "external", "false", "fetch", "fields", "fileformat", "first", "float",
"following", "format", "formatted", "from", "full", "function", "functions", "grant",
"group", "having", "hold_ddltime", "idxproperties", "if", "import", "in", "index",
"indexes", "inpath", "inputdriver", "inputformat", "insert", "int", "intersect", "into",
"is", "items", "join", "keys", "lateral", "left", "lifecycle", "like", "limit", "lines",
"load", "local", "location", "lock", "locks", "long", "map", "mapjoin", "materialized",
"minus", "msck", "not", "no_drop", "null", "of", "offline", "offset", "on", "option",
"or", "order", "out", "outer", "outputdriver", "outputformat", "over", "overwrite",
"partition", "partitioned", "partitionproperties", "partitions", "percent", "plus",
"preceding", "preserve", "procedure", "purge", "range", "rcfile", "read", "readonly",
"reads", "rebuild", "recordreader", "recordwriter", "reduce", "regexp", "rename",
"repair", "replace", "restrict", "revoke", "right", "rlike", "row", "rows", "schema",
"schemas", "select", "semi", "sequencefile", "serde", "serdeproperties", "set", "shared",
"show", "show_database", "smallint", "sort", "sorted", "ssl", "statistics", "status",
"stored", "streamtable", "string", "struct", "table", "tables", "tablesample",
"tblproperties", "temporary", "terminated", "textfile", "then", "timestamp", "tinyint",
"to", "touch", "transform", "trigger", "true", "type", "unarchive", "unbounded", "undo",
"union", "uniontype", "uniquejoin", "unlock", "unsigned", "update", "use", "using",
"utc", "utc_timestamp", "view", "when", "where", "while", "div",
}

reservedKeywordsMap map[string]struct{}
)

func init() {

Check failure on line 40 in ext/store/maxcompute/sanitizer.go

View workflow job for this annotation

GitHub Actions / lint

don't use `init` function (gochecknoinits)
reservedKeywordsMap = make(map[string]struct{})
for _, keyword := range reservedKeywords {
reservedKeywordsMap[keyword] = struct{}{}
}
}

func SafeKeyword(keyword string) string {
if _, exists := reservedKeywordsMap[keyword]; exists {
return common.QuoteRef(keyword)
}

return keyword
}
28 changes: 28 additions & 0 deletions ext/store/maxcompute/sanitizer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package maxcompute

Check failure on line 1 in ext/store/maxcompute/sanitizer_test.go

View workflow job for this annotation

GitHub Actions / lint

package should be `maxcompute_test` instead of `maxcompute` (testpackage)

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSanitizer(t *testing.T) {
t.Run("returns safe keyword", func(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{"select", "`select`"},
{"from", "`from`"},
{"case", "`case`"},
{"customer_name", "customer_name"},
{"other", "other"},
{"table", "`table`"},
}

for _, tc := range testCases {
result := SafeKeyword(tc.input)
assert.Equal(t, tc.expected, result)
}
})
}
14 changes: 7 additions & 7 deletions ext/store/maxcompute/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ func populateColumns(t *Table, schemaBuilder *tableschema.SchemaBuilder) error {
func generateUpdateQuery(incoming, existing tableschema.TableSchema, schemaName string) ([]string, error) {
var sqlTasks []string
if incoming.Comment != existing.Comment {
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set comment %s;", schemaName, existing.TableName, common.QuoteString(incoming.Comment)))
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set comment %s;", SafeKeyword(schemaName), SafeKeyword(existing.TableName), common.QuoteString(incoming.Comment)))
}

if incoming.Lifecycle != existing.Lifecycle {
if incoming.Lifecycle <= 0 && existing.Lifecycle >= 0 {
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s disable lifecycle;", schemaName, existing.TableName))
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s disable lifecycle;", SafeKeyword(schemaName), SafeKeyword(existing.TableName)))
} else if incoming.Lifecycle > 0 {
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set lifecycle %d;", schemaName, existing.TableName, incoming.Lifecycle))
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set lifecycle %d;", SafeKeyword(schemaName), SafeKeyword(existing.TableName), incoming.Lifecycle))
}
}

Expand Down Expand Up @@ -257,7 +257,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR
if incomingColumnRecord.columnValue.NotNull {
return fmt.Errorf("unable to add new required column")
}
segment := fmt.Sprintf("if not exists %s %s", incomingColumnRecord.columnStructure, incomingColumnRecord.columnValue.Type.Name())
segment := fmt.Sprintf("if not exists %s %s", SafeKeyword(incomingColumnRecord.columnStructure), incomingColumnRecord.columnValue.Type.Name())
if incomingColumnRecord.columnValue.Comment != "" {
segment += fmt.Sprintf(" comment %s", common.QuoteString(incomingColumnRecord.columnValue.Comment))
}
Expand All @@ -268,7 +268,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR
if !columnFound.NotNull && incomingColumnRecord.columnValue.NotNull {
return fmt.Errorf("unable to modify column mode from nullable to required")
} else if columnFound.NotNull && !incomingColumnRecord.columnValue.NotNull {
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s null;", schemaName, tableName, columnFound.Name))
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s null;", SafeKeyword(schemaName), SafeKeyword(tableName), SafeKeyword(columnFound.Name)))
}

if columnFound.Type.ID() != incomingColumnRecord.columnValue.Type.ID() {
Expand All @@ -277,7 +277,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR

if incomingColumnRecord.columnValue.Comment != columnFound.Comment {
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s %s %s comment %s;",
schemaName, tableName, columnFound.Name, incomingColumnRecord.columnValue.Name, columnFound.Type, common.QuoteString(incomingColumnRecord.columnValue.Comment)))
SafeKeyword(schemaName), SafeKeyword(tableName), SafeKeyword(columnFound.Name), SafeKeyword(incomingColumnRecord.columnValue.Name), columnFound.Type, common.QuoteString(incomingColumnRecord.columnValue.Comment)))
}
delete(existing, incomingColumnRecord.columnStructure)
}
Expand All @@ -290,7 +290,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR

if len(columnAddition) > 0 {
for _, segment := range columnAddition {
addColumnQuery := fmt.Sprintf("alter table %s.%s add column ", schemaName, tableName) + segment + ";"
addColumnQuery := fmt.Sprintf("alter table %s.%s add column ", SafeKeyword(schemaName), SafeKeyword(tableName)) + segment + ";"
*sqlTasks = append(*sqlTasks, addColumnQuery)
}
}
Expand Down
Loading