diff --git a/ext/store/maxcompute/sanitizer.go b/ext/store/maxcompute/sanitizer.go
new file mode 100644
index 0000000000..960cecc5ff
--- /dev/null
+++ b/ext/store/maxcompute/sanitizer.go
@@ -0,0 +1,64 @@
+package maxcompute
+
+import (
+	"strings"
+
+	"github.com/aliyun/aliyun-odps-go-sdk/odps/common"
+)
+
+var (
+	// reservedKeywords holds the MaxCompute reserved words in lower case.
+	// https://www.alibabacloud.com/help/en/maxcompute/user-guide/reserved-words-and-keywords
+	reservedKeywords = []string{
+		"add", "after", "all", "alter", "analyze", "and", "archive", "array", "as", "asc",
+		"before", "between", "bigint", "binary", "blob", "boolean", "both", "decimal",
+		"bucket", "buckets", "by", "cascade", "case", "cast", "cfile", "change", "cluster",
+		"clustered", "clusterstatus", "collection", "column", "columns", "comment", "compute",
+		"concatenate", "continue", "create", "cross", "current", "cursor", "data", "database",
+		"databases", "date", "datetime", "dbproperties", "deferred", "delete", "delimited",
+		"desc", "describe", "directory", "disable", "distinct", "distribute", "double", "drop",
+		"else", "enable", "end", "except", "escaped", "exclusive", "exists", "explain", "export",
+		"extended", "external", "false", "fetch", "fields", "fileformat", "first", "float",
+		"following", "format", "formatted", "from", "full", "function", "functions", "grant",
+		"group", "having", "hold_ddltime", "idxproperties", "if", "import", "in", "index",
+		"indexes", "inpath", "inputdriver", "inputformat", "insert", "int", "intersect", "into",
+		"is", "items", "join", "keys", "lateral", "left", "lifecycle", "like", "limit", "lines",
+		"load", "local", "location", "lock", "locks", "long", "map", "mapjoin", "materialized",
+		"minus", "msck", "not", "no_drop", "null", "of", "offline", "offset", "on", "option",
+		"or", "order", "out", "outer", "outputdriver", "outputformat", "over", "overwrite",
+		"partition", "partitioned", "partitionproperties", "partitions", "percent", "plus",
+		"preceding", "preserve", "procedure", "purge", "range", "rcfile", "read", "readonly",
+		"reads", "rebuild", "recordreader", "recordwriter", "reduce", "regexp", "rename",
+		"repair", "replace", "restrict", "revoke", "right", "rlike", "row", "rows", "schema",
+		"schemas", "select", "semi", "sequencefile", "serde", "serdeproperties", "set", "shared",
+		"show", "show_database", "smallint", "sort", "sorted", "ssl", "statistics", "status",
+		"stored", "streamtable", "string", "struct", "table", "tables", "tablesample",
+		"tblproperties", "temporary", "terminated", "textfile", "then", "timestamp", "tinyint",
+		"to", "touch", "transform", "trigger", "true", "type", "unarchive", "unbounded", "undo",
+		"union", "uniontype", "uniquejoin", "unlock", "unsigned", "update", "use", "using",
+		"utc", "utc_timestamp", "view", "when", "where", "while", "div",
+	}
+
+	// reservedKeywordsMap is the set built from reservedKeywords by init.
+	reservedKeywordsMap map[string]struct{}
+)
+
+// init fills reservedKeywordsMap with every entry of reservedKeywords.
+func init() {
+	reservedKeywordsMap = make(map[string]struct{}, len(reservedKeywords))
+	for _, keyword := range reservedKeywords {
+		reservedKeywordsMap[keyword] = struct{}{}
+	}
+}
+
+// SafeKeyword returns keyword quoted with common.QuoteRef when it is a
+// reserved word, and returns it unchanged otherwise. The lookup uses the
+// lower-cased keyword so that spellings such as "SELECT" or "Table" are
+// matched as well; the caller's original spelling is what gets quoted.
+func SafeKeyword(keyword string) string {
+	if _, exists := reservedKeywordsMap[strings.ToLower(keyword)]; exists {
+		return common.QuoteRef(keyword)
+	}
+
+	return keyword
+}
diff --git a/ext/store/maxcompute/sanitizer_test.go b/ext/store/maxcompute/sanitizer_test.go
new file mode 100644
index 0000000000..0dbcab40d9
--- /dev/null
+++ b/ext/store/maxcompute/sanitizer_test.go
@@ -0,0 +1,33 @@
+package maxcompute
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// TestSanitizer checks that SafeKeyword quotes reserved words, in any
+// letter case, and leaves other identifiers untouched.
+func TestSanitizer(t *testing.T) {
+	t.Run("returns safe keyword", func(t *testing.T) {
+		testCases := []struct {
+			input    string
+			expected string
+		}{
+			{"select", "`select`"},
+			{"from", "`from`"},
+			{"case", "`case`"},
+			{"customer_name", "customer_name"},
+			{"other", "other"},
+			{"table", "`table`"},
+			{"SELECT", "`SELECT`"},
+			{"Table", "`Table`"},
+			{"Customer_Name", "Customer_Name"},
+		}
+
+		for _, tc := range testCases {
+			result := SafeKeyword(tc.input)
+			assert.Equal(t, tc.expected, result)
+		}
+	})
+}
diff --git a/ext/store/maxcompute/table.go b/ext/store/maxcompute/table.go
index e353e25859..8967ea119f
100644 --- a/ext/store/maxcompute/table.go +++ b/ext/store/maxcompute/table.go @@ -158,14 +158,14 @@ func populateColumns(t *Table, schemaBuilder *tableschema.SchemaBuilder) error { func generateUpdateQuery(incoming, existing tableschema.TableSchema, schemaName string) ([]string, error) { var sqlTasks []string if incoming.Comment != existing.Comment { - sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set comment %s;", schemaName, existing.TableName, common.QuoteString(incoming.Comment))) + sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set comment %s;", SafeKeyword(schemaName), SafeKeyword(existing.TableName), common.QuoteString(incoming.Comment))) } if incoming.Lifecycle != existing.Lifecycle { if incoming.Lifecycle <= 0 && existing.Lifecycle >= 0 { - sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s disable lifecycle;", schemaName, existing.TableName)) + sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s disable lifecycle;", SafeKeyword(schemaName), SafeKeyword(existing.TableName))) } else if incoming.Lifecycle > 0 { - sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set lifecycle %d;", schemaName, existing.TableName, incoming.Lifecycle)) + sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set lifecycle %d;", SafeKeyword(schemaName), SafeKeyword(existing.TableName), incoming.Lifecycle)) } } @@ -257,7 +257,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR if incomingColumnRecord.columnValue.NotNull { return fmt.Errorf("unable to add new required column") } - segment := fmt.Sprintf("if not exists %s %s", incomingColumnRecord.columnStructure, incomingColumnRecord.columnValue.Type.Name()) + segment := fmt.Sprintf("if not exists %s %s", SafeKeyword(incomingColumnRecord.columnStructure), incomingColumnRecord.columnValue.Type.Name()) if incomingColumnRecord.columnValue.Comment != "" { segment += fmt.Sprintf(" comment %s", 
common.QuoteString(incomingColumnRecord.columnValue.Comment)) } @@ -268,7 +268,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR if !columnFound.NotNull && incomingColumnRecord.columnValue.NotNull { return fmt.Errorf("unable to modify column mode from nullable to required") } else if columnFound.NotNull && !incomingColumnRecord.columnValue.NotNull { - *sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s null;", schemaName, tableName, columnFound.Name)) + *sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s null;", SafeKeyword(schemaName), SafeKeyword(tableName), SafeKeyword(columnFound.Name))) } if columnFound.Type.ID() != incomingColumnRecord.columnValue.Type.ID() { @@ -277,7 +277,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR if incomingColumnRecord.columnValue.Comment != columnFound.Comment { *sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s %s %s comment %s;", - schemaName, tableName, columnFound.Name, incomingColumnRecord.columnValue.Name, columnFound.Type, common.QuoteString(incomingColumnRecord.columnValue.Comment))) + SafeKeyword(schemaName), SafeKeyword(tableName), SafeKeyword(columnFound.Name), SafeKeyword(incomingColumnRecord.columnValue.Name), columnFound.Type, common.QuoteString(incomingColumnRecord.columnValue.Comment))) } delete(existing, incomingColumnRecord.columnStructure) } @@ -290,7 +290,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR if len(columnAddition) > 0 { for _, segment := range columnAddition { - addColumnQuery := fmt.Sprintf("alter table %s.%s add column ", schemaName, tableName) + segment + ";" + addColumnQuery := fmt.Sprintf("alter table %s.%s add column ", SafeKeyword(schemaName), SafeKeyword(tableName)) + segment + ";" *sqlTasks = append(*sqlTasks, addColumnQuery) } }