diff --git a/.circleci/config.yml b/.circleci/config.yml index fa729da..fd41d93 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,6 +7,11 @@ jobs: environment: GO111MODULE: "on" - image: circleci/mysql:5.7 + - image: circleci/postgres:12.2 + environment: + POSTGRES_USER: root + POSTGRES_DB: rapidash + POSTGRES_HOST_AUTH_METHOD: trust - image: memcached:1.5 - image: redis:5.0 steps: @@ -25,11 +30,15 @@ jobs: golangci-lint run - run: name: Wait for DB - command: dockerize -wait tcp://127.0.0.1:3306 -timeout 30s + command: dockerize -wait tcp://127.0.0.1:3306 -wait tcp://127.0.0.1:5432 -timeout 30s + - run: + name: Run unit tests for postgres + command: | + RAPIDASH_DB_DRIVER=postgres go test -v ./... - run: name: Run unit tests and measure coverage command: | - go test -v -coverprofile=coverage.out ./... + RAPIDASH_DB_DRIVER=mysql go test -v -coverprofile=coverage.out ./... bash <(curl -s https://codecov.io/bash) -P ${CIRCLE_PULL_REQUEST##*/} workflows: version: 2 diff --git a/benchmark_test.go b/benchmark_test.go index a6a5f0f..d48d139 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -202,7 +202,7 @@ func BenchmarkGetByPrimaryKey_RapidashWorst(b *testing.B) { if err != nil { panic(err) } - builder := NewQueryBuilder("a").Eq("id", id) + builder := NewQueryBuilder("a", driver.Adapter).Eq("id", id) var a A if err := tx.FindByQueryBuilder(builder, &a); err != nil { panic(err) @@ -243,7 +243,7 @@ func BenchmarkGetByPrimaryKey_RapidashBest(b *testing.B) { if err != nil { panic(err) } - builder := NewQueryBuilder("a").Eq("id", id) + builder := NewQueryBuilder("a", driver.Adapter).Eq("id", id) var a A if err := tx.FindByQueryBuilder(builder, &a); err != nil { panic(err) @@ -261,7 +261,7 @@ func BenchmarkGetByPrimaryKey_RapidashBest(b *testing.B) { if err != nil { panic(err) } - builder := NewQueryBuilder("a").Eq("id", id) + builder := NewQueryBuilder("a", driver.Adapter).Eq("id", id) var a A if err := tx.FindByQueryBuilder(builder, &a); err != nil { panic(err) @@ -462,7 +462,7 @@ func BenchmarkUpdateByPrimaryKey_RapidashWorst(b *testing.B) { if err != nil { panic(err) } - builder := NewQueryBuilder("a").Eq("id", id) + builder := NewQueryBuilder("a", driver.Adapter).Eq("id", id) if err := tx.UpdateByQueryBuilder(builder, map[string]interface{}{ "name": "bench2", }); err != nil { @@ -503,7 +503,7 @@ func BenchmarkUpdateByPrimaryKey_RapidashBest(b *testing.B) { if err != nil { panic(err) } - builder := NewQueryBuilder("a").Eq("id", id) + builder := NewQueryBuilder("a", driver.Adapter).Eq("id", id) var a A if err := tx.FindByQueryBuilder(builder, &a); err != nil { panic(err) @@ -524,7 +524,7 @@ func BenchmarkUpdateByPrimaryKey_RapidashBest(b *testing.B) { if err != nil { panic(err) } - builder := NewQueryBuilder("a").Eq("id", id) + builder := NewQueryBuilder("a", driver.Adapter).Eq("id", id) if err := tx.UpdateByQueryBuilder(builder, map[string]interface{}{ "name": "bench2", }); err != nil { @@ -622,7 +622,7 @@ func BenchmarkDeleteByPrimaryKey_Rapidash(b *testing.B) { if err != nil { panic(err) } - builder := NewQueryBuilder("a").Eq("id", id) + builder := NewQueryBuilder("a", driver.Adapter).Eq("id", id) if err := tx.DeleteByQueryBuilder(builder); err != nil { panic(err) } diff --git a/config_test.go b/config_test.go index cb91120..6827dc5 100644 --- a/config_test.go +++ b/config_test.go @@ -10,10 +10,10 @@ import ( func TestConfig(t *testing.T) { cfg, err := NewConfig("testdata/cache.yml") NoError(t, err) - cache, err := New(cfg.Options()...) + cache, err := New(append([]OptionFunc{DatabaseAdapter(driver.DBType)}, cfg.Options()...)...) NoError(t, err) NoError(t, cache.Flush()) - conn, err := sql.Open("mysql", "root:@tcp(localhost:3306)/rapidash?parseTime=true") + conn, err := sql.Open(driver.Name, driver.Source) NoError(t, err) NoError(t, cache.WarmUp(conn, userLoginType(), false)) t.Run("create new records", func(t *testing.T) { @@ -43,7 +43,7 @@ func TestConfig(t *testing.T) { tx, err := cache.Begin(txConn) NoError(t, err) for i := 1001; i <= 1005; i++ { - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(i)). Eq("user_session_id", uint64(i)) var foundUserLogin UserLogin @@ -173,7 +173,7 @@ func TestConfig(t *testing.T) { t.Run("should retrieve each handled data", func(t *testing.T) { key := fmt.Sprintf("key_%d", time.Now().UnixNano()) - var resultFirst int + var resultFirst int var resultSecond int expectFirst := 1 expectSecond := 2 @@ -213,7 +213,7 @@ func TestConfig(t *testing.T) { t.Run("implicit cache control", func(t *testing.T) { t.Run("should retrieve each handled data", func(t *testing.T) { key := fmt.Sprintf("key_%d", time.Now().UnixNano()) - var resultFirst int + var resultFirst int var resultSecond int expectFirst := 1 expectSecond := 2 diff --git a/database/database.go b/database/database.go new file mode 100644 index 0000000..01995ff --- /dev/null +++ b/database/database.go @@ -0,0 +1,124 @@ +package database + +import ( + "database/sql" + + "go.knocknote.io/rapidash/database/mysql" + "go.knocknote.io/rapidash/database/postgres" +) + +type DBType int + +const ( + None DBType = iota + MySQL + Postgres +) + +const ( + mysqlPlugin = "mysql" + postgresPlugin = "postgres" +) + +type Database interface { + database + Placeholder(int) string + Placeholders(int, int) string +} + +type database interface { + TableDDL(*sql.DB, string) (string, error) + Quote(string) string + SupportLastInsertID() bool +} + +type QueryHelper struct { + count int + adapter *adapter +} + +func (qh *QueryHelper) DBType() DBType { + return qh.adapter.DBType +} + +func (qh *QueryHelper) Placeholder() string { + qh.count++ + return qh.adapter.Placeholder(qh.count) +} + +func (qh *QueryHelper) Placeholders(n int) string { + start := qh.count + 1 + end := start + n - 1 + qh.count += n + return qh.adapter.Placeholders(start, end) +} + +func (qh *QueryHelper) Quote(str string) string { + return qh.adapter.Quote(str) +} + +func (qh *QueryHelper) SupportLastInsertID() bool { + return qh.adapter.SupportLastInsertID() +} + +func (qh *QueryHelper) ClearCount() { + qh.count = 0 +} + +type Adapter interface { + database + QueryHelper() *QueryHelper +} + +type adapter struct { + DBType DBType + Database +} + +func (d *adapter) QueryHelper() *QueryHelper { + return &QueryHelper{ + count: 0, + adapter: d, + } +} + +func NewAdapter() *adapter { + drivers := sql.Drivers() + if len(drivers) == 0 { + return nil + } + dbType := toDBType(drivers[0]) + return &adapter{ + DBType: dbType, + Database: NewDatabase(dbType), + } +} + +func NewAdapterWithDBType(dbType DBType) *adapter { + return &adapter{ + DBType: dbType, + Database: NewDatabase(dbType), + } +} + +func NewDatabase(dbType DBType) Database { + switch dbType { + case MySQL: + return &mysql.MySQL{} + case Postgres: + return &postgres.Postgres{} + } + return nil +} + +func toDBType(pluginName string) DBType { + switch pluginName { + case mysqlPlugin: + return MySQL + case postgresPlugin: + return Postgres + } + return None +} + +var _ Adapter = (*adapter)(nil) diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go new file mode 100644 index 0000000..8203081 --- /dev/null +++ b/database/mysql/mysql.go @@ -0,0 +1,46 @@ +package mysql + +import ( + "database/sql" + "fmt" + "strings" + + "golang.org/x/xerrors" +) + +type MySQL struct{} + +func (ms *MySQL) TableDDL(conn *sql.DB, tableName string) (string, error) { + var ( + tbl string + ddl string + ) + if err := conn.QueryRow(fmt.Sprintf("SHOW CREATE TABLE `%s`", tableName)).Scan(&tbl, &ddl); err != nil { + return "", xerrors.Errorf("failed to execute 'SHOW CREATE TABLE `%s`': %w", tableName, err) + } + return ddl, nil +} + +func (ms *MySQL) Placeholder(_ int) string { + return "?" +} + +func (ms *MySQL) Placeholders(start, end int) string { + sb := &strings.Builder{} + sb.Grow((len(ms.Placeholder(0)) + 1) * (end - start + 1)) + for i := start; i <= end; i++ { + sb.WriteString(ms.Placeholder(0)) + if i < end { + sb.WriteString(",") + } + } + return sb.String() +} + +func (ms *MySQL) Quote(s string) string { + return fmt.Sprintf("`%s`", s) +} + +func (ms *MySQL) SupportLastInsertID() bool { + return true +} diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go new file mode 100644 index 0000000..3cb5f19 --- /dev/null +++ b/database/postgres/postgres.go @@ -0,0 +1,224 @@ +package postgres + +import ( + "database/sql" + "fmt" + "strings" + + "golang.org/x/xerrors" +) + +type Postgres struct{} + +func (p *Postgres) Placeholder(idx int) string { + return fmt.Sprintf("$%d", idx) +} + +func (p *Postgres) Placeholders(start, end int) string { + sb := &strings.Builder{} + sb.Grow((len(p.Placeholder(end)) + 1) * (end - start + 1)) + for i := start; i <= end; i++ { + sb.WriteString(p.Placeholder(i)) + if i < end { + sb.WriteString(",") + } + } + return sb.String() +} + +func (p *Postgres) Quote(s string) string { + return fmt.Sprintf(`"%s"`, s) +} + +func (p *Postgres) SupportLastInsertID() bool { + return false +} + +func (p *Postgres) TableDDL(conn *sql.DB, table string) (string, error) { + cols, err := p.getColumns(conn, table) + if err != nil { + return "", xerrors.Errorf("failed to get columns: %w", err) + } + primaryKeyDef, err := p.getPrimaryKeyDef(conn, table) + if err != nil { + return "", xerrors.Errorf("failed to get primary key def: %w", err) + } + indexDefs, err := p.getIndexDefs(conn, table) + if err != nil { + return "", xerrors.Errorf("failed to get index defs: %w", err) + } + return p.buildDDL(table, cols, primaryKeyDef, indexDefs), nil +} + +func (p *Postgres) buildDDL(table string, columns []*column, primaryKeyDef string, indexDefs []string) string { + builder := &strings.Builder{} + builder.WriteString(fmt.Sprintf("CREATE TABLE public.%s (\n", table)) + for i, col := range columns { + builder.WriteString(indent) + builder.WriteString(fmt.Sprintf("%s %s", col.Name, col.DataType())) + if !col.Nullable { + builder.WriteString(" NOT NULL") + } + if i < len(columns)-1 || primaryKeyDef != "" || len(indexDefs) > 0 { + builder.WriteString(",\n") + } else { + builder.WriteString("\n") + } + } + if primaryKeyDef != "" { + builder.WriteString(primaryKeyDef) + if len(indexDefs) > 0 { + builder.WriteString(",\n") + } else { + builder.WriteString("\n") + } + } + for idx, def := range indexDefs { + defTxt := strings.Split(def, " ") + for _, v := range defTxt { + if v == "UNIQUE" { + builder.WriteString(fmt.Sprintf("%s%s ", indent, v)) + continue + } + if v == "INDEX" { + builder.WriteString(fmt.Sprintf("%s ", "KEY")) + continue + } + if strings.Contains(v, "(") { + builder.WriteString(v) + continue + } + if strings.Contains(v, ")") { + builder.WriteString(v) + continue + } + } + if idx < len(indexDefs)-1 { + builder.WriteString(",\n") + } else { + builder.WriteString("\n") + } + } + builder.WriteString(");") + return builder.String() +} + +func (p *Postgres) getColumns(conn *sql.DB, table string) ([]*column, error) { + query := "SELECT column_name, data_type, is_nullable FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name=$1;" + rows, err := conn.Query(query, table) + if err != nil { + return nil, xerrors.Errorf("failed to exec query %s: %w", query, err) + } + defer rows.Close() + + var cols []*column + for rows.Next() { + var colName, nullable, dataType string + if err := rows.Scan(&colName, &dataType, &nullable); err != nil { + return nil, xerrors.Errorf("failed to scan index key def: %w", err) + } + cols = append(cols, &column{ + Name: strings.Trim(colName, `" `), + dataType: dataType, + Nullable: nullable == "YES", + }) + } + return cols, nil +} + +func (p *Postgres) getIndexDefs(conn *sql.DB, table string) ([]string, error) { + query := "SELECT indexName, indexdef FROM pg_indexes WHERE tablename=$1" + rows, err := conn.Query(query, table) + if err != nil { + if err == sql.ErrNoRows { + return []string{}, nil + } + return nil, xerrors.Errorf("failed to exec query %s: %w", query, err) + } + defer rows.Close() + + var indexes []string + for rows.Next() { + var indexName, indexdef string + if err := rows.Scan(&indexName, &indexdef); err != nil { + return nil, xerrors.Errorf("failed to scan index key def: %w", err) + } + indexName = strings.Trim(indexName, `" `) + if strings.HasSuffix(indexName, "_pkey") { + continue + } + indexes = append(indexes, indexdef) + } + return indexes, nil +} + +func (p *Postgres) getPrimaryKeyDef(conn *sql.DB, table string) (string, error) { + query := `SELECT kcu.column_name FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name + WHERE constraint_type = 'PRIMARY KEY' AND tc.table_name=$1` + rows, err := conn.Query(query, table) + if err != nil { + if err == sql.ErrNoRows { + return "", nil + } + return "", xerrors.Errorf("failed to exec query %s: %w", query, err) + } + defer rows.Close() + + var columnNames []string + for rows.Next() { + var columnName string + if err := rows.Scan(&columnName); err != nil { + return "", xerrors.Errorf("failed to scan primary key def: %w", err) + } + columnNames = append(columnNames, columnName) + } + if len(columnNames) == 0 { + return "", nil + } + return fmt.Sprintf("%sPRIMARY KEY (%s)", indent, strings.Join(columnNames, ",")), nil +} + +type column struct { + Name string + dataType string + Nullable bool +} + +const ( + smallint = "smallint" + smallserial = "smallserial" + integer = "integer" + serial = "serial" + bigint = "bigint" + bigserial = "bigserial" + timestampWithoutTimeZone = "timestamp without time zone" + timestampWithTimeZone = "timestamp with time zone" + timestamp = "timestamp" + timeWithoutTimeZone = "time without time zone" + timeWithTimeZone = "time with time zone" + time = "time" + userDefined = "USER-DEFINED" + char = "char" + varchar = "varchar" + characterVarying = "character varying" + + indent = " " +) + +// Rapidash gets DDL from database to get the index(including unique key, primary key) information. +// Therefore, instead of getting the strict DDL, the data type is processed so that at least this information can be parsed. +func (c *column) DataType() string { + switch c.dataType { + case smallint, integer, bigint, smallserial, serial, bigserial: + return c.dataType + case timestampWithoutTimeZone, timestampWithTimeZone: + return timestamp + case timeWithoutTimeZone, timeWithTimeZone: + return time + case userDefined, characterVarying, varchar: + return char + default: + return c.dataType + } +} diff --git a/first_level_cache.go b/first_level_cache.go index 0b4cb90..5a25085 100644 --- a/first_level_cache.go +++ b/first_level_cache.go @@ -8,6 +8,7 @@ import ( "sync" "github.com/knocknote/vitess-sqlparser/sqlparser" + "go.knocknote.io/rapidash/database" "golang.org/x/xerrors" ) @@ -37,13 +38,15 @@ type FirstLevelCache struct { findAllValue *StructSliceValue primaryKey string valueFactory *ValueFactory + adapter database.Adapter } -func NewFirstLevelCache(s *Struct) *FirstLevelCache { +func NewFirstLevelCache(s *Struct, adapter database.Adapter) *FirstLevelCache { return &FirstLevelCache{ typ: s, indexTrees: map[string]*BTree{}, valueFactory: NewValueFactory(), + adapter: adapter, } } @@ -87,12 +90,9 @@ func (c *FirstLevelCache) WarmUp(conn *sql.DB) (e error) { } func (c *FirstLevelCache) showCreateTable(conn *sql.DB) (string, error) { - var ( - tbl string - ddl string - ) - if err := conn.QueryRow(fmt.Sprintf("SHOW CREATE TABLE `%s`", c.typ.tableName)).Scan(&tbl, &ddl); err != nil { - return "", xerrors.Errorf("failed to execute 'SHOW CREATE TABLE `%s`': %w", c.typ.tableName, err) + ddl, err := c.adapter.TableDDL(conn, c.typ.tableName) + if err != nil { + return "", xerrors.Errorf("failed to get ddl for %s: %w", c.typ.tableName) } return ddl, nil } @@ -101,7 +101,7 @@ func (c *FirstLevelCache) loadAll(conn *sql.DB) (*sql.Rows, error) { columns := c.typ.Columns() escapedColumns := make([]string, len(columns)) for idx, column := range columns { - escapedColumns[idx] = fmt.Sprintf("`%s`", column) + escapedColumns[idx] = c.adapter.Quote(column) } query := fmt.Sprintf("SELECT %s FROM %s", strings.Join(escapedColumns, ","), c.typ.tableName) rows, err := conn.Query(query) diff --git a/first_level_cache_test.go b/first_level_cache_test.go index 13e7eb6..be66456 100644 --- a/first_level_cache_test.go +++ b/first_level_cache_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - _ "github.com/go-sql-driver/mysql" + "go.knocknote.io/rapidash/database" "golang.org/x/xerrors" ) @@ -89,17 +89,17 @@ func emptyType() *Struct { } func TestPK(t *testing.T) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) var event Event NoError(t, flc.FindByPrimaryKey(NewUint64Value(uint64(1)), &event)) } func TestINQuery(t *testing.T) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) t.Run("in query", func(t *testing.T) { - builder := NewQueryBuilder("events").In("id", []uint64{1, 2, 3, 4, 5}) + builder := NewQueryBuilder("events", driver.Adapter).In("id", []uint64{1, 2, 3, 4, 5}) var events EventSlice NoError(t, flc.FindByQueryBuilder(builder, &events)) if len(events) != 5 { @@ -107,7 +107,7 @@ func TestINQuery(t *testing.T) { } }) t.Run("in query by duplicated values", func(t *testing.T) { - builder := NewQueryBuilder("events").In("id", []uint64{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}) + builder := NewQueryBuilder("events", driver.Adapter).In("id", []uint64{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}) var events EventSlice NoError(t, flc.FindByQueryBuilder(builder, &events)) if len(events) != 5 { @@ -117,7 +117,7 @@ func TestINQuery(t *testing.T) { } func TestFindAll(t *testing.T) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) var events EventSlice NoError(t, flc.FindAll(&events)) @@ -127,9 +127,9 @@ func TestFindAll(t *testing.T) { } func TestComplicatedQuery(t *testing.T) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Eq("event_id", uint64(1)). Gte("start_week", uint8(12)). Lte("end_week", uint8(24)). @@ -146,7 +146,7 @@ func TestNEQQuery(t *testing.T) { NoError(t, err) t.Run("index column", func(t *testing.T) { { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Eq("term", "daytime"). Gte("start_week", uint8(1)). Neq("end_week", uint8(12)) @@ -155,7 +155,7 @@ func TestNEQQuery(t *testing.T) { Equal(t, len(events), 3000) } { - builder := NewQueryBuilder("events").Neq("id", uint64(1)) + builder := NewQueryBuilder("events", driver.Adapter).Neq("id", uint64(1)) var events EventSlice NoError(t, tx.FindByQueryBuilder(builder, &events)) Equal(t, len(events), 3999) @@ -163,13 +163,13 @@ func TestNEQQuery(t *testing.T) { }) t.Run("not index column", func(t *testing.T) { { - builder := NewQueryBuilder("events").Neq("end_week", uint8(12)) + builder := NewQueryBuilder("events", driver.Adapter).Neq("end_week", uint8(12)) var events EventSlice NoError(t, tx.FindByQueryBuilder(builder, &events)) Equal(t, len(events), 3000) } { - builder := NewQueryBuilder("events").Lte("end_week", uint8(100)).Neq("end_week", uint8(12)) + builder := NewQueryBuilder("events", driver.Adapter).Lte("end_week", uint8(100)).Neq("end_week", uint8(12)) var events EventSlice NoError(t, tx.FindByQueryBuilder(builder, &events)) Equal(t, len(events), 3000) @@ -179,10 +179,10 @@ func TestNEQQuery(t *testing.T) { } func TestGteAndLteQuery(t *testing.T) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) t.Run("primary key column", func(t *testing.T) { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Gte("id", uint64(1)). Lte("id", uint64(5)) var events EventSlice @@ -192,14 +192,14 @@ func TestGteAndLteQuery(t *testing.T) { t.Run("index column", func(t *testing.T) { { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Gte("event_id", uint64(900)) var events EventSlice NoError(t, flc.FindByQueryBuilder(builder, &events)) Equal(t, len(events), 404) } { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Lte("event_id", uint64(1000)). Gte("event_id", uint64(900)). Eq("start_week", uint8(1)) @@ -208,7 +208,7 @@ func TestGteAndLteQuery(t *testing.T) { Equal(t, len(events), 101) } { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Lte("event_id", uint64(1000)). Gte("event_id", uint64(900)). Eq("event_id", uint64(1)) @@ -217,7 +217,7 @@ func TestGteAndLteQuery(t *testing.T) { Equal(t, len(events), 0) } { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). In("term", []string{"morning", "daytime"}). Gte("start_week", uint8(12)). Lte("start_week", uint8(25)) @@ -229,7 +229,7 @@ func TestGteAndLteQuery(t *testing.T) { t.Run("not index column", func(t *testing.T) { { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Gte("start_week", uint8(12)). Lte("start_week", uint8(25)) var events EventSlice @@ -238,7 +238,7 @@ func TestGteAndLteQuery(t *testing.T) { } { now := time.Now() - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Lte("updated_at", now). Gte("updated_at", now.Add(time.Hour*24*7)) var events EventSlice @@ -249,10 +249,10 @@ func TestGteAndLteQuery(t *testing.T) { } func TestGtAndLtQuery(t *testing.T) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) t.Run("primary key column", func(t *testing.T) { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Gt("id", uint64(0)). Lt("id", uint64(6)) var events EventSlice @@ -262,14 +262,14 @@ func TestGtAndLtQuery(t *testing.T) { t.Run("index column", func(t *testing.T) { { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Gt("event_id", uint64(900)) var events EventSlice NoError(t, flc.FindByQueryBuilder(builder, &events)) Equal(t, len(events), 400) } { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Gt("event_id", uint64(900)). Lt("event_id", uint64(1000)) var events EventSlice @@ -277,7 +277,7 @@ func TestGtAndLtQuery(t *testing.T) { Equal(t, len(events), 396) } { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Lt("event_id", uint64(1000)). Gt("event_id", uint64(900)). Eq("start_week", uint8(1)) @@ -289,7 +289,7 @@ func TestGtAndLtQuery(t *testing.T) { t.Run("not index column", func(t *testing.T) { { - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Gt("start_week", uint8(12)). Lt("start_week", uint8(25)) var events EventSlice @@ -298,7 +298,7 @@ func TestGtAndLtQuery(t *testing.T) { } { now := time.Now() - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Lt("updated_at", now). Gt("updated_at", now.Add(time.Hour*24*7)) var events EventSlice @@ -309,9 +309,9 @@ func TestGtAndLtQuery(t *testing.T) { } func TestOrderQuery(t *testing.T) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Eq("event_id", uint64(1)). Gte("start_week", uint8(12)). OrderDesc("id"). @@ -331,9 +331,9 @@ func TestOrderQuery(t *testing.T) { } func TestCountQueryFLC(t *testing.T) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) - builder := NewQueryBuilder("events"). + builder := NewQueryBuilder("events", driver.Adapter). Eq("event_id", uint64(1)) count, err := flc.CountByQueryBuilder(builder) NoError(t, err) @@ -352,7 +352,7 @@ func TestPtrType(t *testing.T) { tx, err := cache.Begin() NoError(t, err) t.Run("EQ", func(t *testing.T) { - builder := NewQueryBuilder("ptr").Eq("intptr", 1).Eq("int8ptr", int8(2)). + builder := NewQueryBuilder("ptr", driver.Adapter).Eq("intptr", 1).Eq("int8ptr", int8(2)). Eq("int16ptr", int16(3)).Eq("int32ptr", int32(4)).Eq("int64ptr", int64(5)). Eq("uintptr", uint(6)).Eq("uint8ptr", uint8(7)).Eq("uint16ptr", uint16(8)). Eq("uint32ptr", uint32(9)).Eq("uint64ptr", uint64(10)).Eq("float32ptr", float32(1.23)). @@ -372,7 +372,7 @@ func TestPtrType(t *testing.T) { }) t.Run("NEQ", func(t *testing.T) { - builder := NewQueryBuilder("ptr").Neq("intptr", 1).Neq("int8ptr", int8(2)). + builder := NewQueryBuilder("ptr", driver.Adapter).Neq("intptr", 1).Neq("int8ptr", int8(2)). Neq("int16ptr", int16(3)).Neq("int32ptr", int32(4)).Neq("int64ptr", int64(5)). Neq("uintptr", uint(6)).Neq("uint8ptr", uint8(7)).Neq("uint16ptr", uint16(8)). Neq("uint32ptr", uint32(9)).Neq("uint64ptr", uint64(10)).Neq("float32ptr", float32(1.23)). @@ -384,23 +384,23 @@ func TestPtrType(t *testing.T) { }) t.Run("LTE AND GTE", func(t *testing.T) { builders := []*QueryBuilder{ - NewQueryBuilder("ptr").Gte("intptr", 1).Lte("intptr", 1), - NewQueryBuilder("ptr").Gte("int8ptr", int8(2)).Lte("int8ptr", int8(2)), - NewQueryBuilder("ptr").Gte("int16ptr", int16(3)).Lte("int16ptr", int16(3)), - NewQueryBuilder("ptr").Gte("int32ptr", int32(4)).Lte("int32ptr", int32(4)), - NewQueryBuilder("ptr").Gte("int64ptr", int64(5)).Lte("int64ptr", int64(5)), - NewQueryBuilder("ptr").Gte("uintptr", uint(6)).Lte("uintptr", uint(6)), - NewQueryBuilder("ptr").Gte("uint8ptr", uint8(7)).Lte("uint8ptr", uint8(7)), - NewQueryBuilder("ptr").Gte("uint16ptr", uint16(8)).Lte("uint16ptr", uint16(8)), - NewQueryBuilder("ptr").Gte("uint32ptr", uint32(9)).Lte("uint32ptr", uint32(9)), - NewQueryBuilder("ptr").Gte("uint64ptr", uint64(10)).Lte("uint64ptr", uint64(10)), - NewQueryBuilder("ptr").Gte("float32ptr", float32(1.23)).Lte("float32ptr", float32(1.23)), - NewQueryBuilder("ptr").Gte("float64ptr", float64(4.56)).Lte("float64ptr", float64(4.56)), - NewQueryBuilder("ptr").Gte("bytesptr", []byte("bytes")).Lte("bytesptr", []byte("bytes")), - NewQueryBuilder("ptr").Gte("stringptr", "string").Lte("stringptr", "string"), - NewQueryBuilder("ptr").Gte("boolptr", false), - NewQueryBuilder("ptr").Lte("boolptr", true), - NewQueryBuilder("ptr").Gte("timeptr", time.Now().Add(-time.Hour*24)).Lte("timeptr", time.Now().Add(time.Hour*24)), + NewQueryBuilder("ptr", driver.Adapter).Gte("intptr", 1).Lte("intptr", 1), + NewQueryBuilder("ptr", driver.Adapter).Gte("int8ptr", int8(2)).Lte("int8ptr", int8(2)), + NewQueryBuilder("ptr", driver.Adapter).Gte("int16ptr", int16(3)).Lte("int16ptr", int16(3)), + NewQueryBuilder("ptr", driver.Adapter).Gte("int32ptr", int32(4)).Lte("int32ptr", int32(4)), + NewQueryBuilder("ptr", driver.Adapter).Gte("int64ptr", int64(5)).Lte("int64ptr", int64(5)), + NewQueryBuilder("ptr", driver.Adapter).Gte("uintptr", uint(6)).Lte("uintptr", uint(6)), + NewQueryBuilder("ptr", driver.Adapter).Gte("uint8ptr", uint8(7)).Lte("uint8ptr", uint8(7)), + NewQueryBuilder("ptr", driver.Adapter).Gte("uint16ptr", uint16(8)).Lte("uint16ptr", uint16(8)), + NewQueryBuilder("ptr", driver.Adapter).Gte("uint32ptr", uint32(9)).Lte("uint32ptr", uint32(9)), + NewQueryBuilder("ptr", driver.Adapter).Gte("uint64ptr", uint64(10)).Lte("uint64ptr", uint64(10)), + NewQueryBuilder("ptr", driver.Adapter).Gte("float32ptr", float32(1.23)).Lte("float32ptr", float32(1.23)), + NewQueryBuilder("ptr", driver.Adapter).Gte("float64ptr", float64(4.56)).Lte("float64ptr", float64(4.56)), + NewQueryBuilder("ptr", driver.Adapter).Gte("bytesptr", []byte("bytes")).Lte("bytesptr", []byte("bytes")), + NewQueryBuilder("ptr", driver.Adapter).Gte("stringptr", "string").Lte("stringptr", "string"), + NewQueryBuilder("ptr", driver.Adapter).Gte("boolptr", false), + NewQueryBuilder("ptr", driver.Adapter).Lte("boolptr", true), + NewQueryBuilder("ptr", driver.Adapter).Gte("timeptr", time.Now().Add(-time.Hour*24)).Lte("timeptr", time.Now().Add(time.Hour*24)), } for _, builder := range builders { @@ -416,23 +416,23 @@ func TestPtrType(t *testing.T) { }) t.Run("LT AND GT", func(t *testing.T) { builders := []*QueryBuilder{ - NewQueryBuilder("ptr").Gt("intptr", 0).Lt("intptr", 2), - NewQueryBuilder("ptr").Gt("int8ptr", int8(1)).Lt("int8ptr", int8(3)), - NewQueryBuilder("ptr").Gt("int16ptr", int16(2)).Lt("int16ptr", int16(4)), - NewQueryBuilder("ptr").Gt("int32ptr", int32(3)).Lt("int32ptr", int32(5)), - NewQueryBuilder("ptr").Gt("int64ptr", int64(4)).Lt("int64ptr", int64(6)), - NewQueryBuilder("ptr").Gt("uintptr", uint(5)).Lt("uintptr", uint(7)), - NewQueryBuilder("ptr").Gt("uint8ptr", uint8(6)).Lt("uint8ptr", uint8(8)), - NewQueryBuilder("ptr").Gt("uint16ptr", uint16(7)).Lt("uint16ptr", uint16(9)), - NewQueryBuilder("ptr").Gt("uint32ptr", uint32(8)).Lt("uint32ptr", uint32(10)), - NewQueryBuilder("ptr").Gt("uint64ptr", uint64(9)).Lt("uint64ptr", uint64(11)), - NewQueryBuilder("ptr").Gt("float32ptr", float32(1.22)).Lt("float32ptr", float32(1.24)), - NewQueryBuilder("ptr").Gt("float64ptr", float64(4.55)).Lt("float64ptr", float64(4.57)), - NewQueryBuilder("ptr").Gt("bytesptr", []byte("byte")).Lt("bytesptr", []byte("bytess")), - NewQueryBuilder("ptr").Gt("stringptr", "strin").Lt("stringptr", "strings"), - NewQueryBuilder("ptr").Gt("boolptr", false), - NewQueryBuilder("ptr").Lt("boolptr", true), - NewQueryBuilder("ptr").Gt("timeptr", time.Now().Add(-time.Hour*24)).Lt("timeptr", time.Now().Add(time.Hour*24)), + NewQueryBuilder("ptr", driver.Adapter).Gt("intptr", 0).Lt("intptr", 2), + NewQueryBuilder("ptr", driver.Adapter).Gt("int8ptr", int8(1)).Lt("int8ptr", int8(3)), + NewQueryBuilder("ptr", driver.Adapter).Gt("int16ptr", int16(2)).Lt("int16ptr", int16(4)), + NewQueryBuilder("ptr", driver.Adapter).Gt("int32ptr", int32(3)).Lt("int32ptr", int32(5)), + NewQueryBuilder("ptr", driver.Adapter).Gt("int64ptr", int64(4)).Lt("int64ptr", int64(6)), + NewQueryBuilder("ptr", driver.Adapter).Gt("uintptr", uint(5)).Lt("uintptr", uint(7)), + NewQueryBuilder("ptr", driver.Adapter).Gt("uint8ptr", uint8(6)).Lt("uint8ptr", uint8(8)), + NewQueryBuilder("ptr", driver.Adapter).Gt("uint16ptr", uint16(7)).Lt("uint16ptr", uint16(9)), + NewQueryBuilder("ptr", driver.Adapter).Gt("uint32ptr", uint32(8)).Lt("uint32ptr", uint32(10)), + NewQueryBuilder("ptr", driver.Adapter).Gt("uint64ptr", uint64(9)).Lt("uint64ptr", uint64(11)), + NewQueryBuilder("ptr", driver.Adapter).Gt("float32ptr", float32(1.22)).Lt("float32ptr", float32(1.24)), + NewQueryBuilder("ptr", driver.Adapter).Gt("float64ptr", float64(4.55)).Lt("float64ptr", float64(4.57)), + NewQueryBuilder("ptr", driver.Adapter).Gt("bytesptr", []byte("byte")).Lt("bytesptr", []byte("bytess")), + NewQueryBuilder("ptr", driver.Adapter).Gt("stringptr", "strin").Lt("stringptr", "strings"), + NewQueryBuilder("ptr", driver.Adapter).Gt("boolptr", false), + NewQueryBuilder("ptr", driver.Adapter).Lt("boolptr", true), + NewQueryBuilder("ptr", driver.Adapter).Gt("timeptr", time.Now().Add(-time.Hour*24)).Lt("timeptr", time.Now().Add(time.Hour*24)), } for _, builder := range builders { @@ -450,7 +450,7 @@ func TestPtrType(t *testing.T) { } func TestFindByPrimaryKeyCaseDatabaseRecordIsEmpty(t *testing.T) { - flc := NewFirstLevelCache(emptyType()) + flc := NewFirstLevelCache(emptyType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) var empty Empty NoError(t, flc.FindByPrimaryKey(NewUint64Value(uint64(1)), &empty)) @@ -460,9 +460,9 @@ func TestFindByPrimaryKeyCaseDatabaseRecordIsEmpty(t *testing.T) { } func TestFindByQueryBuilderCaseDatabaseRecordIsEmpty(t *testing.T) { - flc := NewFirstLevelCache(emptyType()) + flc := NewFirstLevelCache(emptyType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) - builder := NewQueryBuilder("empties").In("id", []uint64{1}) + builder := NewQueryBuilder("empties", driver.Adapter).In("id", []uint64{1}) var empties EmptySlice NoError(t, flc.FindByQueryBuilder(builder, &empties)) if len(empties) != 0 { @@ -471,9 +471,9 @@ func TestFindByQueryBuilderCaseDatabaseRecordIsEmpty(t *testing.T) { } func TestCountByQueryBuilderCaseDatabaseRecordIsEmptyFLC(t *testing.T) { - flc := NewFirstLevelCache(emptyType()) + flc := NewFirstLevelCache(emptyType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) - builder := NewQueryBuilder("empties").Eq("id", uint64(1)) + builder := NewQueryBuilder("empties", driver.Adapter).Eq("id", uint64(1)) count, err := flc.CountByQueryBuilder(builder) NoError(t, err) if count != 0 { @@ -482,7 +482,7 @@ func TestCountByQueryBuilderCaseDatabaseRecordIsEmptyFLC(t *testing.T) { } func TestFindAllCaseDatabaseRecordIsEmpty(t *testing.T) { - flc := NewFirstLevelCache(emptyType()) + flc := NewFirstLevelCache(emptyType(), database.NewAdapterWithDBType(driver.DBType)) NoError(t, flc.WarmUp(conn)) var empties EmptySlice NoError(t, flc.FindAll(&empties)) @@ -527,7 +527,7 @@ func BenchmarkPK_MySQL(b *testing.B) { } func BenchmarkPK_Rapidash(b *testing.B) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) if err := flc.WarmUp(conn); err != nil { panic(err) } @@ -591,11 +591,11 @@ func BenchmarkIN_MySQL(b *testing.B) { } func BenchmarkIN_Rapidash(b *testing.B) { - flc := NewFirstLevelCache(eventType()) + flc := NewFirstLevelCache(eventType(), database.NewAdapterWithDBType(driver.DBType)) if err := flc.WarmUp(conn); err != nil { panic(err) } - builder := NewQueryBuilder("events").In("id", []uint64{1, 2, 3, 4, 5}) + builder := NewQueryBuilder("events", driver.Adapter).In("id", []uint64{1, 2, 3, 4, 5}) b.ResetTimer() events := []*Event{} for n := 0; n < b.N; n++ { diff --git a/go.mod b/go.mod index ddec99f..fad5f9b 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/juju/testing v0.0.0-20190418112600-6570bd8f8541 // indirect github.com/knocknote/msgpack v0.0.0-20200324090114-032ce7b219cd github.com/knocknote/vitess-sqlparser v0.0.0-20181121014348-1003c43917a3 + github.com/lib/pq v1.1.1 github.com/rakyll/statik v0.1.6 github.com/rs/xid v0.0.0-20180316063648-705291fb2231 github.com/rs/zerolog v1.13.0 diff --git a/option.go b/option.go index 217c3bb..c8a1f00 100644 --- a/option.go +++ b/option.go @@ -2,6 +2,8 @@ package rapidash import ( "time" + + "go.knocknote.io/rapidash/database" ) type OptionFunc func(*Rapidash) @@ -200,6 +202,12 @@ func LastLevelCacheTagLockExpiration(tag string, expiration time.Duration) Optio } } +func DatabaseAdapter(dbType database.DBType) OptionFunc { + return func(r *Rapidash) { + r.opt.adapter = database.NewAdapterWithDBType(dbType) + } +} + func LastLevelCacheTagOptimisticLock(tag string, enabled bool) OptionFunc { return func(r *Rapidash) { opt := r.opt.llcOpt.tagOpt[tag] diff --git a/query_builder.go b/query_builder.go index 473ea9c..f0ebc5e 100644 --- a/query_builder.go +++ b/query_builder.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/knocknote/vitess-sqlparser/sqlparser" + "go.knocknote.io/rapidash/database" "go.knocknote.io/rapidash/server" "golang.org/x/xerrors" ) @@ -246,6 +247,7 @@ func NewValueIterator(keys []server.CacheKey) *ValueIterator { type Queries struct { tableName string + queryHelper *database.QueryHelper primaryIndex *Index queries []*Query cacheMissQueries []*Query @@ -255,9 +257,10 @@ type Queries struct { isAllSQL bool } -func NewQueries(tableName string, primaryIndex *Index, queryNum int) *Queries { +func NewQueries(tableName string, primaryIndex *Index, queryNum int, queryHelper *database.QueryHelper) *Queries { return &Queries{ tableName: tableName, + queryHelper: queryHelper, primaryIndex: primaryIndex, queries: make([]*Query, 0, queryNum), cacheMissQueries: []*Query{}, @@ -370,20 +373,21 @@ func (q *Queries) FindCacheMissQueryByStructValue(value *StructValue) *Query { } func (q *Queries) CacheMissQueriesToSQL(typ *Struct) (string, []interface{}) { + q.queryHelper.ClearCount() escapedColumns := []string{} for _, column := range typ.Columns() { - escapedColumns = append(escapedColumns, fmt.Sprintf("`%s`", column)) + escapedColumns = append(escapedColumns, q.queryHelper.Quote(column)) } if q.rawSQL != "" { - return fmt.Sprintf("SELECT %s FROM `%s` %s", + return fmt.Sprintf("SELECT %s FROM %s %s", strings.Join(escapedColumns, ","), - q.tableName, + q.queryHelper.Quote(q.tableName), q.rawSQL, ), q.rawSQLValues } else if q.isAllSQL { - return fmt.Sprintf("SELECT %s FROM `%s`", + return fmt.Sprintf("SELECT %s FROM %s", strings.Join(escapedColumns, ","), - q.tableName, + q.queryHelper.Quote(q.tableName), ), nil } if len(q.cacheMissQueries) == 0 { @@ -418,15 +422,15 @@ func (q *Queries) CacheMissQueriesToSQL(typ *Struct) (string, []interface{}) { } else { queryArgs = append(queryArgs, v.RawValue()) } - placeholders = append(placeholders, "?") + placeholders = append(placeholders, q.queryHelper.Placeholder()) } - condition = fmt.Sprintf("`%s` IN (%s)", column, strings.Join(placeholders, ",")) + condition = fmt.Sprintf("%s IN (%s)", q.queryHelper.Quote(column), strings.Join(placeholders, ",")) } else { if !value.IsNil { queryArgs = append(queryArgs, value.RawValue()) - condition = fmt.Sprintf("`%s` = ?", column) + condition = fmt.Sprintf("%s = %s", q.queryHelper.Quote(column), q.queryHelper.Placeholder()) } else { - condition = fmt.Sprintf("`%s` IS NULL", column) + condition = fmt.Sprintf("%s IS NULL", q.queryHelper.Quote(column)) } } @@ -436,9 +440,9 @@ func (q *Queries) CacheMissQueriesToSQL(typ *Struct) (string, []interface{}) { if lockOpt != "" { lockOpt = " " + lockOpt } - return fmt.Sprintf("SELECT %s FROM `%s` WHERE %s%s", + return fmt.Sprintf("SELECT %s FROM %s WHERE %s%s", strings.Join(escapedColumns, ","), - q.tableName, + q.queryHelper.Quote(q.tableName), strings.Join(conditions, " AND "), lockOpt, ), queryArgs @@ -529,6 +533,7 @@ func (b *QueryBuilder) AvailableIndex() bool { } type QueryBuilder struct { + queryHelper *database.QueryHelper tableName string conditions *Conditions inCondition *INCondition @@ -540,13 +545,14 @@ type QueryBuilder struct { cachedQueries *Queries } -func NewQueryBuilder(tableName string) *QueryBuilder { +func NewQueryBuilder(tableName string, adapter database.Adapter) *QueryBuilder { return &QueryBuilder{ tableName: tableName, conditions: &Conditions{ conditions: []Condition{}, }, orderConditions: []*OrderCondition{}, + queryHelper: adapter.QueryHelper(), } } @@ -601,6 +607,7 @@ func (b *QueryBuilder) indexes() []string { } func (b *QueryBuilder) SelectSQL(typ *Struct) (string, []interface{}) { + b.queryHelper.ClearCount() where := []string{} args := []interface{}{} for _, condition := range b.conditions.conditions { @@ -609,45 +616,47 @@ func (b *QueryBuilder) SelectSQL(typ *Struct) (string, []interface{}) { } escapedColumns := []string{} for _, column := range typ.Columns() { - escapedColumns = append(escapedColumns, fmt.Sprintf("`%s`", column)) + escapedColumns = append(escapedColumns, b.queryHelper.Quote(column)) } lockOpt := b.lockOpt.String() if lockOpt != "" { lockOpt = " " + lockOpt } - return fmt.Sprintf("SELECT %s FROM `%s` WHERE %s%s", + return fmt.Sprintf("SELECT %s FROM %s WHERE %s%s", strings.Join(escapedColumns, ","), - b.tableName, + b.queryHelper.Quote(b.tableName), strings.Join(where, " AND "), lockOpt, ), args } func (b *QueryBuilder) UpdateSQL(updateMap map[string]interface{}) (string, []interface{}) { + b.queryHelper.ClearCount() + setList := []string{} + values := []interface{}{} + for k, v := range updateMap { + setList = append(setList, fmt.Sprintf("%s = %s", b.queryHelper.Quote(k), b.queryHelper.Placeholder())) + values = append(values, v) + } where := []string{} args := []interface{}{} for _, condition := range b.conditions.conditions { where = append(where, condition.Query()) args = append(args, condition.QueryArgs()...) } - setList := []string{} - values := []interface{}{} - for k, v := range updateMap { - setList = append(setList, fmt.Sprintf("`%s` = ?", k)) - values = append(values, v) - } values = append(values, args...) - return fmt.Sprintf("UPDATE `%s` SET %s WHERE %s", b.tableName, strings.Join(setList, ","), strings.Join(where, " AND ")), values + return fmt.Sprintf("UPDATE %s SET %s WHERE %s", b.queryHelper.Quote(b.tableName), strings.Join(setList, ","), strings.Join(where, " AND ")), values } func (b *QueryBuilder) DeleteSQL() (string, []interface{}) { + b.queryHelper.ClearCount() where := []string{} args := []interface{}{} for _, condition := range b.conditions.conditions { where = append(where, condition.Query()) args = append(args, condition.QueryArgs()...) } - return fmt.Sprintf("DELETE FROM `%s` WHERE %s", b.tableName, strings.Join(where, " AND ")), args + return fmt.Sprintf("DELETE FROM %s WHERE %s", b.queryHelper.Quote(b.tableName), strings.Join(where, " AND ")), args } func (b *QueryBuilder) Release() { @@ -661,7 +670,7 @@ func (b *QueryBuilder) Build(factory *ValueFactory) { func (b *QueryBuilder) buildINQueryWithIndex(indexes map[string]*Index) (*Queries, error) { queryNum := len(b.inCondition.values) columnNum := len(b.conditions.conditions) - queries := NewQueries(b.tableName, b.primaryIndexFromIndexes(indexes), queryNum) + queries := NewQueries(b.tableName, b.primaryIndexFromIndexes(indexes), queryNum, b.queryHelper) for i := 0; i < queryNum; i++ { queries.Add(NewQuery(columnNum)) } @@ -699,26 +708,33 @@ func (b *QueryBuilder) buildINQueryWithIndex(indexes map[string]*Index) (*Querie func (b *QueryBuilder) buildAllQuery() *Queries { b.isIgnoreCache = true return &Queries{ - tableName: b.tableName, - isAllSQL: true, - queries: make([]*Query, 1), + tableName: b.tableName, + queryHelper: b.queryHelper, + isAllSQL: true, + queries: make([]*Query, 1), } } func (b *QueryBuilder) buildRawQuery() (*Queries, error) { - prefix := fmt.Sprintf("SELECT * FROM `%s` ", b.tableName) - stmt, err := sqlparser.Parse(prefix + b.sqlCondition.stmt) - if err != nil { - return nil, xerrors.Errorf("failed to parse %s: %w", prefix+b.sqlCondition.stmt, err) - } - selectStmt := stmt.(*sqlparser.Select) - if selectStmt.GroupBy != nil || - selectStmt.Having != nil || - selectStmt.OrderBy != nil { + // vitess-sqlparser supports only mysql type sql syntax. + if b.queryHelper.DBType() == database.MySQL { + prefix := fmt.Sprintf("SELECT * FROM %s ", b.queryHelper.Quote(b.tableName)) + stmt, err := sqlparser.Parse(prefix + b.sqlCondition.stmt) + if err != nil { + return nil, xerrors.Errorf("failed to parse %s: %w", prefix+b.sqlCondition.stmt, err) + } + selectStmt := stmt.(*sqlparser.Select) + if selectStmt.GroupBy != nil || + selectStmt.Having != nil || + selectStmt.OrderBy != nil { + b.isIgnoreCache = true + } + } else { b.isIgnoreCache = true } return &Queries{ tableName: b.tableName, + queryHelper: b.queryHelper, rawSQL: b.sqlCondition.stmt, rawSQLValues: b.sqlCondition.rawValues, queries: make([]*Query, 1), @@ -784,7 +800,7 @@ func (b *QueryBuilder) BuildWithIndex(factory *ValueFactory, indexes map[string] return queries, nil } columnNum := len(b.conditions.conditions) - queries := NewQueries(b.tableName, b.primaryIndexFromIndexes(indexes), 1) + queries := NewQueries(b.tableName, b.primaryIndexFromIndexes(indexes), 1, b.queryHelper) queries.lockOpt = b.lockOpt query := NewQuery(columnNum) for _, condition := range b.conditions.conditions { @@ -812,32 +828,32 @@ func (b *QueryBuilder) Query() string { } func (b *QueryBuilder) Eq(column string, value interface{}) *QueryBuilder { - b.conditions.Append(&EQCondition{column: column, rawValue: value}) + b.conditions.Append(&EQCondition{column: column, rawValue: value, queryHelper: b.queryHelper}) return b } func (b *QueryBuilder) Neq(column string, value interface{}) *QueryBuilder { - b.conditions.Append(&NEQCondition{column: column, rawValue: value}) + b.conditions.Append(&NEQCondition{column: column, rawValue: value, queryHelper: b.queryHelper}) return b } func (b *QueryBuilder) Gt(column string, value interface{}) *QueryBuilder { - b.conditions.Append(>Condition{column: column, rawValue: value}) + b.conditions.Append(>Condition{column: column, rawValue: value, queryHelper: b.queryHelper}) return b } func (b *QueryBuilder) Lt(column string, value interface{}) *QueryBuilder { - b.conditions.Append(<Condition{column: column, rawValue: value}) + b.conditions.Append(<Condition{column: column, rawValue: value, queryHelper: b.queryHelper}) return b } func (b *QueryBuilder) Gte(column string, value interface{}) *QueryBuilder { - b.conditions.Append(>ECondition{column: column, rawValue: value}) + b.conditions.Append(>ECondition{column: column, rawValue: value, queryHelper: b.queryHelper}) return b } func (b *QueryBuilder) Lte(column string, value interface{}) *QueryBuilder { - b.conditions.Append(<ECondition{column: column, rawValue: value}) + b.conditions.Append(<ECondition{column: column, rawValue: value, queryHelper: b.queryHelper}) return b } @@ -846,7 +862,7 @@ func (b *QueryBuilder) In(column string, values interface{}) *QueryBuilder { b.err = ErrMultipleINQueries return b } - condition := &INCondition{column: column, rawValues: values} + condition := &INCondition{column: column, rawValues: values, queryHelper: b.queryHelper} b.inCondition = condition b.conditions.Append(condition) return b @@ -937,9 +953,10 @@ func (b *QueryBuilder) IsUnsupportedCacheQuery() bool { } type EQCondition struct { - column string - rawValue interface{} - value *Value + queryHelper *database.QueryHelper + column string + rawValue interface{} + value *Value } func (c *EQCondition) Column() string { @@ -964,9 +981,9 @@ func (c *EQCondition) Search(tree *BTree) []Leaf { func (c *EQCondition) Query() string { if c.rawValue == nil { - return fmt.Sprintf("`%s` IS NULL", c.column) + return fmt.Sprintf("%s IS NULL", c.queryHelper.Quote(c.column)) } - return fmt.Sprintf("`%s` = ?", c.column) + return fmt.Sprintf("%s = %s", c.queryHelper.Quote(c.column), c.queryHelper.Placeholder()) } func (c *EQCondition) QueryArgs() []interface{} { @@ -992,9 +1009,10 @@ func (c *EQCondition) Release() { } type NEQCondition struct { - column string - rawValue interface{} - value *Value + queryHelper *database.QueryHelper + column string + rawValue interface{} + value *Value } func (c *NEQCondition) Column() string { @@ -1007,9 +1025,9 @@ func (c *NEQCondition) Value() *Value { func (c *NEQCondition) Query() string { if c.rawValue == nil { - return fmt.Sprintf("`%s` IS NOT NULL", c.column) + return fmt.Sprintf("%s IS NOT NULL", c.queryHelper.Quote(c.column)) } - return fmt.Sprintf("`%s` != ?", c.column) + return fmt.Sprintf("%s != %s", c.queryHelper.Quote(c.column), c.queryHelper.Placeholder()) } func (c *NEQCondition) QueryArgs() []interface{} { @@ -1044,9 +1062,10 @@ func (c *NEQCondition) Release() { } type GTCondition struct { - column string - rawValue interface{} - value *Value + queryHelper *database.QueryHelper + column string + rawValue interface{} + value *Value } func (c *GTCondition) Column() string { @@ -1058,7 +1077,7 @@ func (c *GTCondition) Value() *Value { } func (c *GTCondition) Query() string { - return fmt.Sprintf("`%s` > ?", c.column) + return fmt.Sprintf("%s > %s", c.queryHelper.Quote(c.column), c.queryHelper.Placeholder()) } func (c *GTCondition) QueryArgs() []interface{} { @@ -1089,9 +1108,10 @@ func (c *GTCondition) Release() { } type GTECondition struct { - column string - rawValue interface{} - value *Value + queryHelper *database.QueryHelper + column string + rawValue interface{} + value *Value } func (c *GTECondition) Column() string { @@ -1103,7 +1123,7 @@ func (c *GTECondition) Value() *Value { } func (c *GTECondition) Query() string { - return fmt.Sprintf("`%s` >= ?", c.column) + return fmt.Sprintf("%s >= %s", c.queryHelper.Quote(c.column), c.queryHelper.Placeholder()) } func (c *GTECondition) QueryArgs() []interface{} { @@ -1134,9 +1154,10 @@ func (c *GTECondition) Release() { } type LTCondition struct { - column string - rawValue interface{} - value *Value + queryHelper *database.QueryHelper + column string + rawValue interface{} + value *Value } func (c *LTCondition) Column() string { @@ -1148,7 +1169,7 @@ func (c *LTCondition) Value() *Value { } func (c *LTCondition) Query() string { - return fmt.Sprintf("`%s` < ?", c.column) + return fmt.Sprintf("%s < %s", c.queryHelper.Quote(c.column), c.queryHelper.Placeholder()) } func (c *LTCondition) QueryArgs() []interface{} { @@ -1179,9 +1200,10 @@ func (c *LTCondition) Release() { } type LTECondition struct { - column string - rawValue interface{} - value *Value + queryHelper *database.QueryHelper + column string + rawValue interface{} + value *Value } func (c *LTECondition) Column() string { @@ -1193,7 +1215,7 @@ func (c *LTECondition) Value() *Value { } func (c *LTECondition) Query() string { - return fmt.Sprintf("`%s` <= ?", c.column) + return fmt.Sprintf("%s <= %s", c.queryHelper.Quote(c.column), c.queryHelper.Placeholder()) } func (c *LTECondition) QueryArgs() []interface{} { @@ -1224,9 +1246,10 @@ func (c *LTECondition) Release() { } type INCondition struct { - column string - rawValues interface{} - values []*Value + queryHelper *database.QueryHelper + column string + rawValues interface{} + values []*Value } func (c *INCondition) Column() string { @@ -1238,11 +1261,7 @@ func (c *INCondition) Value() *Value { } func (c *INCondition) Query() string { - placeholders := make([]string, len(c.values)) - for i := 0; i < len(c.values); i++ { - placeholders[i] = "?" - } - return fmt.Sprintf("`%s` IN (%s)", c.column, strings.Join(placeholders, ",")) + return fmt.Sprintf("%s IN (%s)", c.queryHelper.Quote(c.column), c.queryHelper.Placeholders(len(c.values))) } func (c *INCondition) QueryArgs() []interface{} { diff --git a/rapidash.go b/rapidash.go index ded038a..c76b6e8 100644 --- a/rapidash.go +++ b/rapidash.go @@ -9,6 +9,7 @@ import ( "time" "github.com/rs/xid" + "go.knocknote.io/rapidash/database" "go.knocknote.io/rapidash/server" "golang.org/x/xerrors" ) @@ -139,6 +140,7 @@ type QueryLog struct { } type Option struct { + adapter database.Adapter serverType CacheServerType serverAddrs []string timeout time.Duration @@ -164,6 +166,7 @@ type Option struct { func defaultOption() Option { return Option{ + adapter: database.NewAdapter(), serverType: CacheServerTypeMemcached, timeout: DefaultTimeout, maxIdleConnections: DefaultMaxIdleConns, @@ -187,6 +190,7 @@ func defaultOption() Option { type Connection interface { QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row ExecContext(context.Context, string, ...interface{}) (sql.Result, error) } @@ -846,7 +850,7 @@ func (r *Rapidash) WarmUp(conn *sql.DB, typ *Struct, isReadOnly bool) error { } func (r *Rapidash) WarmUpFirstLevelCache(conn *sql.DB, typ *Struct) error { - flc := NewFirstLevelCache(typ) + flc := NewFirstLevelCache(typ, r.opt.adapter) if err := flc.WarmUp(conn); err != nil { return xerrors.Errorf("cannot warm up FirstLevelCache. table is %s: %w", typ.tableName, err) } @@ -872,7 +876,7 @@ func (r *Rapidash) tableOption(tableName string) TableOption { } func (r *Rapidash) WarmUpSecondLevelCache(conn *sql.DB, typ *Struct) error { - slc := NewSecondLevelCache(typ, r.cacheServer, r.tableOption(typ.tableName)) + slc := NewSecondLevelCache(typ, r.cacheServer, r.tableOption(typ.tableName), r.opt.adapter) if err := slc.WarmUp(conn); err != nil { return xerrors.Errorf("cannot warm up SecondLevelCache. table is %s: %w", typ.tableName, err) } @@ -1028,5 +1032,9 @@ func New(opts ...OptionFunc) (*Rapidash, error) { return nil, xerrors.Errorf("failed to set server: %w", err) } r.setLogger() + + if r.opt.adapter == nil { + return nil, xerrors.Errorf("adapter not setup. must pass database adapter option") + } return r, nil } diff --git a/rapidash_test.go b/rapidash_test.go index 5c3f451..d9e66d0 100644 --- a/rapidash_test.go +++ b/rapidash_test.go @@ -11,7 +11,7 @@ import ( func TestServerChanging(t *testing.T) { t.Run("remove and add server", func(t *testing.T) { - cache, err := New(ServerAddrs([]string{"localhost:11211"}), MaxIdleConnections(1000), Timeout(200*time.Millisecond), LastLevelCachePessimisticLock(true)) + cache, err := New(ServerAddrs([]string{"localhost:11211"}), MaxIdleConnections(1000), Timeout(200*time.Millisecond), LastLevelCachePessimisticLock(true), DatabaseAdapter(driver.DBType)) NoError(t, err) tx, err := cache.Begin() NoErrorf(t, err, "cannot begin cache transaction") @@ -24,14 +24,14 @@ func TestServerChanging(t *testing.T) { }) t.Run("remove and add only slc server", func(t *testing.T) { - cache, err := New(ServerAddrs([]string{"localhost:11211"}), MaxIdleConnections(1000), Timeout(200*time.Millisecond), LastLevelCachePessimisticLock(true)) + cache, err := New(ServerAddrs([]string{"localhost:11211"}), MaxIdleConnections(1000), Timeout(200*time.Millisecond), LastLevelCachePessimisticLock(true), DatabaseAdapter(driver.DBType)) NoError(t, err) NoError(t, cache.WarmUp(conn, userLoginType(), false)) tx, err := cache.Begin(conn) NoErrorf(t, err, "cannot begin cache transaction") NoErrorf(t, cache.RemoveSecondLevelCacheServers("localhost:11211"), "cannot remove server") - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin Errorf(t, tx.FindByQueryBuilder(builder, &v), "find slc cache") NoErrorf(t, tx.Create("int", Int(1)), "cannot create cache") @@ -43,7 +43,7 @@ func TestServerChanging(t *testing.T) { }) t.Run("remove and add only llc server", func(t *testing.T) { - cache, err := New(ServerAddrs([]string{"localhost:11211"}), MaxIdleConnections(1000), Timeout(200000000000), LastLevelCachePessimisticLock(true)) + cache, err := New(ServerAddrs([]string{"localhost:11211"}), MaxIdleConnections(1000), Timeout(200000000000), LastLevelCachePessimisticLock(true), DatabaseAdapter(driver.DBType)) NoError(t, err) tx, err := cache.Begin() NoErrorf(t, err, "cannot begin cache transaction") @@ -59,7 +59,7 @@ func TestServerChanging(t *testing.T) { func TestRecover(t *testing.T) { txConn, err := conn.Begin() NoError(t, err) - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin { tx, err := cache.Begin(txConn) diff --git a/second_level_cache.go b/second_level_cache.go index eed162b..a5ed184 100644 --- a/second_level_cache.go +++ b/second_level_cache.go @@ -9,8 +9,9 @@ import ( "sync" "time" - "github.com/knocknote/vitess-sqlparser/sqlparser" "github.com/knocknote/msgpack" + "github.com/knocknote/vitess-sqlparser/sqlparser" + "go.knocknote.io/rapidash/database" "go.knocknote.io/rapidash/server" "golang.org/x/xerrors" ) @@ -45,6 +46,7 @@ type SecondLevelCache struct { valueDecoderPool sync.Pool primaryKeyDecoderPool sync.Pool valueFactory *ValueFactory + adapter database.Adapter } type TxValue struct { @@ -91,7 +93,7 @@ func (v *TxValue) EncodeLog() string { return v.String() } -func NewSecondLevelCache(s *Struct, server server.CacheServer, opt TableOption) *SecondLevelCache { +func NewSecondLevelCache(s *Struct, server server.CacheServer, opt TableOption, adapter database.Adapter) *SecondLevelCache { valueFactory := NewValueFactory() return &SecondLevelCache{ typ: s, @@ -110,6 +112,7 @@ func NewSecondLevelCache(s *Struct, server server.CacheServer, opt TableOption) }, }, valueFactory: valueFactory, + adapter: adapter, } } @@ -152,12 +155,9 @@ func (c *SecondLevelCache) WarmUp(conn *sql.DB) error { } func (c *SecondLevelCache) showCreateTable(conn *sql.DB) (string, error) { - var ( - tbl string - ddl string - ) - if err := conn.QueryRow(fmt.Sprintf("SHOW CREATE TABLE `%s`", c.typ.tableName)).Scan(&tbl, &ddl); err != nil { - return "", xerrors.Errorf("failed to execute 'SHOW CREATE TABLE `%s`': %w", c.typ.tableName, err) + ddl, err := c.adapter.TableDDL(conn, c.typ.tableName) + if err != nil { + return "", xerrors.Errorf("failed to get ddl for %s: %w", c.typ.tableName) } return ddl, nil } @@ -176,7 +176,7 @@ func (c *SecondLevelCache) setupPrimaryKey(constraint *sqlparser.Constraint) { } primaryKey := strings.Join(columns, ":") for idx := range columns { - subColumns := columns[:idx+1:idx+1] + subColumns := columns[: idx+1 : idx+1] if len(subColumns) == 0 { continue } @@ -1166,22 +1166,27 @@ func (c *SecondLevelCache) createByQueryWithValues(tx *Tx, query *Query, values } func (c *SecondLevelCache) insertSQL(value *StructValue) (string, []interface{}) { + qh := c.adapter.QueryHelper() escapedColumns := []string{} placeholders := []string{} values := []interface{}{} for _, column := range value.typ.Columns() { - escapedColumns = append(escapedColumns, fmt.Sprintf("`%s`", column)) - placeholders = append(placeholders, "?") - if value.fields[column] == nil { - values = append(values, nil) - } else { + if value.fields[column] != nil { + escapedColumns = append(escapedColumns, qh.Quote(column)) + placeholders = append(placeholders, qh.Placeholder()) values = append(values, value.fields[column].RawValue()) } } - return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", - c.typ.tableName, + + var returningPhrase string + if !qh.SupportLastInsertID() { + returningPhrase = fmt.Sprintf("RETURNING %s", qh.Quote("id")) + } + return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) %s", + qh.Quote(c.typ.tableName), strings.Join(escapedColumns, ","), strings.Join(placeholders, ","), + returningPhrase, ), values } @@ -1192,26 +1197,12 @@ func (c *SecondLevelCache) Create(ctx context.Context, tx *Tx, marshaler Marshal return } defer value.Release() - sql, values := c.insertSQL(value) - result, err := tx.conn.ExecContext(ctx, sql, values...) + lastInsertID, err := c.insertIntoDB(ctx, tx, value) if err != nil { - e = xerrors.Errorf("failed sql %s %v: %w", sql, values, err) - return - } - lastInsertID, err := result.LastInsertId() - if err != nil { - e = xerrors.Errorf("failed to get last_insert_id(): %w", err) + e = xerrors.Errorf("failed to insert into db: %w", err) return } id = lastInsertID - for _, column := range c.primaryKey.Columns { - if value.fields[column] == nil { - // if value for primary key is not defined, - // rapidash assume that result.LastInsertId() can use alternatively. - value.fields[column] = c.valueFactory.CreateInt64Value(lastInsertID) - } - } - log.InsertIntoDB(tx.id, sql, values, value) if err := c.deleteKeyByValue(tx, value); err != nil { e = xerrors.Errorf("failed to delete key by value: %w", err) return @@ -1226,23 +1217,41 @@ func (c *SecondLevelCache) CreateWithoutCache(ctx context.Context, tx *Tx, marsh return } defer value.Release() - sql, values := c.insertSQL(value) - result, err := tx.conn.ExecContext(ctx, sql, values...) + lastInsertID, err := c.insertIntoDB(ctx, tx, value) if err != nil { - e = xerrors.Errorf("failed sql %s %v: %w", sql, values, err) + e = xerrors.Errorf("failed to insert into db: %w", err) return } - lastInsertID, err := result.LastInsertId() - if err != nil { - e = xerrors.Errorf("failed to get last_insert_id(): %w", err) - return + return lastInsertID, nil +} + +func (c *SecondLevelCache) insertIntoDB(ctx context.Context, tx *Tx, value *StructValue) (id int64, e error) { + sql, values := c.insertSQL(value) + // postgres does not support last_insert_id(). postgres supports RETURNING * instead of this. + if c.adapter.SupportLastInsertID() { + result, err := tx.conn.ExecContext(ctx, sql, values...) + if err != nil { + e = xerrors.Errorf("failed sql %s %v: %w", sql, values, err) + return + } + lastInsertID, err := result.LastInsertId() + if err != nil { + e = xerrors.Errorf("failed to get last_insert_id(): %w", err) + return + } + id = lastInsertID + } else { + if err := tx.conn.QueryRowContext(ctx, sql, values...).Scan(&id); err != nil { + e = xerrors.Errorf("failed to scan value: %w", err) + return + } } - id = lastInsertID + for _, column := range c.primaryKey.Columns { if value.fields[column] == nil { // if value for primary key is not defined, // rapidash assume that result.LastInsertId() can use alternatively. - value.fields[column] = c.valueFactory.CreateInt64Value(lastInsertID) + value.fields[column] = c.valueFactory.CreateInt64Value(id) } } log.InsertIntoDB(tx.id, sql, values, value) @@ -1364,7 +1373,7 @@ func (c *SecondLevelCache) DeleteByQueryBuilder(ctx context.Context, tx *Tx, bui } func (c *SecondLevelCache) builderByValue(value *StructValue, index *Index) *QueryBuilder { - builder := NewQueryBuilder(c.typ.tableName) + builder := NewQueryBuilder(c.typ.tableName, c.adapter) for _, column := range index.Columns { if value.fields[column] == nil { return nil @@ -1377,13 +1386,13 @@ func (c *SecondLevelCache) builderByValue(value *StructValue, index *Index) *Que func (c *SecondLevelCache) updateBuilderByValue(value *StructValue, index *Index, updateMap map[string]interface{}) *QueryBuilder { switch index.Type { case IndexTypePrimaryKey: - builder := NewQueryBuilder(c.typ.tableName) + builder := NewQueryBuilder(c.typ.tableName, c.adapter) for _, column := range index.Columns { builder.Eq(column, value.fields[column].RawValue()) } return builder case IndexTypeKey, IndexTypeUniqueKey: - builder := NewQueryBuilder(c.typ.tableName) + builder := NewQueryBuilder(c.typ.tableName, c.adapter) for _, column := range index.Columns { if _, exists := updateMap[column]; !exists { return nil diff --git a/second_level_cache_test.go b/second_level_cache_test.go index 31dc324..8dd7aaf 100644 --- a/second_level_cache_test.go +++ b/second_level_cache_test.go @@ -1,13 +1,18 @@ package rapidash import ( + "bufio" "bytes" "context" "fmt" "net" + "os" + "path/filepath" "testing" "time" + "go.knocknote.io/rapidash/database" + "go.knocknote.io/rapidash/server" "golang.org/x/xerrors" ) @@ -223,7 +228,7 @@ func TestSimpleRead(t *testing.T) { func testSimpleRead(t *testing.T, typ CacheServerType) { NoError(t, initCache(conn, typ)) userLogin := defaultUserLogin() - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) @@ -233,7 +238,7 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { tx, err := cache.Begin(txConn) NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) @@ -249,7 +254,7 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() t.Run("from cache server", func(t *testing.T) { - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) @@ -257,7 +262,7 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { Equal(t, v.Name, userLogin.Name) }) t.Run("from stash", func(t *testing.T) { - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) @@ -272,7 +277,7 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { tx, err := cache.Begin(txConn) NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) @@ -280,12 +285,24 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { Equal(t, v.Name, userLogin.Name) NoError(t, tx.Commit()) + + f, err := os.Open(filepath.Join("testdata", driver.Name, "alter_user_logins.sql")) + NoError(t, err) + defer f.Close() + queryScanner := bufio.NewScanner(f) t.Run("ADD COLUMN", func(t *testing.T) { + { + txConn, err := conn.Begin() + NoError(t, err) + queryScanner.Scan() + if _, err := txConn.Exec(queryScanner.Text()); err != nil { + NoError(t, txConn.Rollback()) + t.Fatalf("%+v", err) + } + NoError(t, txConn.Commit()) + } txConn, err := conn.Begin() NoError(t, err) - if _, err := txConn.Exec("ALTER TABLE user_logins ADD password varchar(10) DEFAULT '100'"); err != nil { - t.Fatalf("%+v", err) - } NoError(t, cache.WarmUpSecondLevelCache(conn, userLoginType().FieldString("password"))) tx, err := cache.Begin(txConn) NoError(t, err) @@ -293,7 +310,7 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLoginAfterAddColumn NoError(t, tx.FindByQueryBuilder(builder, &v)) @@ -303,11 +320,18 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { NoError(t, tx.Commit()) }) t.Run("MODIFY COLUMN", func(t *testing.T) { + { + txConn, err := conn.Begin() + NoError(t, err) + queryScanner.Scan() + if _, err := txConn.Exec(queryScanner.Text()); err != nil { + NoError(t, txConn.Rollback()) + t.Fatalf("%+v", err) + } + NoError(t, txConn.Commit()) + } txConn, err := conn.Begin() NoError(t, err) - if _, err := txConn.Exec("ALTER TABLE user_logins MODIFY COLUMN password int(20) unsigned"); err != nil { - t.Fatalf("%+v", err) - } NoError(t, cache.WarmUpSecondLevelCache(conn, userLoginType().FieldUint64("password"))) tx, err := cache.Begin(txConn) NoError(t, err) @@ -315,7 +339,7 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLoginReTyped NoError(t, tx.FindByQueryBuilder(builder, &v)) @@ -325,11 +349,18 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { NoError(t, tx.Commit()) }) t.Run("DROP COLUMN", func(t *testing.T) { + { + txConn, err := conn.Begin() + NoError(t, err) + queryScanner.Scan() + if _, err := txConn.Exec(queryScanner.Text()); err != nil { + NoError(t, txConn.Rollback()) + t.Fatalf("%+v", err) + } + NoError(t, txConn.Commit()) + } txConn, err := conn.Begin() NoError(t, err) - if _, err := txConn.Exec("ALTER TABLE user_logins DROP COLUMN password"); err != nil { - t.Fatalf("%+v", err) - } NoError(t, cache.WarmUpSecondLevelCache(conn, userLoginType())) tx, err := cache.Begin(txConn) @@ -338,7 +369,7 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin NoError(t, tx.FindByQueryBuilder(builder, &v)) @@ -354,7 +385,7 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { tx, err := cache.Begin(txConn) NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(10000)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(10000)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) @@ -369,14 +400,14 @@ func testSimpleRead(t *testing.T, typ CacheServerType) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() t.Run("from cache server", func(t *testing.T) { - builder := NewQueryBuilder("user_logins").Eq("id", uint64(10000)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(10000)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) Equal(t, v.ID, uint64(0)) }) t.Run("from stash", func(t *testing.T) { - builder := NewQueryBuilder("user_logins").Eq("id", uint64(10000)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(10000)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) @@ -402,7 +433,7 @@ func testSimpleReadWithPessimisticLock(t *testing.T, typ CacheServerType) { pessimisticLock: &pessimisticLock, lockExpiration: &lockExpiration, expiration: &expiration, - }) + }, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) @@ -410,7 +441,7 @@ func testSimpleReadWithPessimisticLock(t *testing.T, typ CacheServerType) { NoError(t, err) tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) if v.ID != userLogin.ID { @@ -422,7 +453,7 @@ func testSimpleReadWithPessimisticLock(t *testing.T, typ CacheServerType) { t.Run("find locked value by another tx", func(t *testing.T) { tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin Error(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) }) @@ -439,7 +470,7 @@ func testSimpleCreate(t *testing.T, typ CacheServerType) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) userLogin := defaultUserLogin() - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) @@ -459,7 +490,7 @@ func testSimpleCreate(t *testing.T, typ CacheServerType) { if userLogin.ID != 1001 { t.Fatal("cannot assign id") } - builder := NewQueryBuilder("user_logins").Eq("user_id", uint64(2)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("user_id", uint64(2)) var foundUserLogin UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &foundUserLogin)) @@ -478,7 +509,7 @@ func TestSimpleUpdate(t *testing.T) { func testSimpleUpdate(t *testing.T, typ CacheServerType) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) txConn, err := conn.Begin() @@ -486,7 +517,7 @@ func testSimpleUpdate(t *testing.T, typ CacheServerType) { tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) @@ -514,7 +545,7 @@ func testSimpleDelete(t *testing.T, typ CacheServerType) { NoError(t, initCache(conn, typ)) userLogin := defaultUserLogin() - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) txConn, err := conn.Begin() @@ -522,7 +553,7 @@ func testSimpleDelete(t *testing.T, typ CacheServerType) { tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var v UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) @@ -544,7 +575,7 @@ func testCreateWithoutCache(t *testing.T, typ CacheServerType) { NoError(t, initCache(conn, typ)) userLogin := defaultUserLogin() - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) @@ -564,7 +595,7 @@ func testCreateWithoutCache(t *testing.T, typ CacheServerType) { if userLogin.ID != 1001 { t.Fatal("cannot insert record") } - builder := NewQueryBuilder("user_logins").Eq("user_id", uint64(3)).Eq("user_session_id", uint64(2)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("user_id", uint64(3)).Eq("user_session_id", uint64(2)) var foundUserLogin UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &foundUserLogin)) @@ -583,9 +614,9 @@ func TestQueryBuilder(t *testing.T) { func testQueryBuilder(t *testing.T, typ CacheServerType) { t.Run("WHERE IN AND EQ query", func(t *testing.T) { NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) queries, err := builder.BuildWithIndex(slc.valueFactory, slc.indexes, slc.typ) @@ -594,17 +625,23 @@ func testQueryBuilder(t *testing.T, typ CacheServerType) { return server.ErrCacheMiss })) query, _ := queries.CacheMissQueriesToSQL(slc.typ) - if query != "SELECT `id`,`user_id`,`user_session_id`,`login_param_id`,`name`,`created_at`,`updated_at` FROM `user_logins` WHERE `user_id` IN (?,?,?,?,?) AND `user_session_id` = ?" { - t.Fatal("invalid query") + if driver.DBType == database.MySQL { + if query != "SELECT `id`,`user_id`,`user_session_id`,`login_param_id`,`name`,`created_at`,`updated_at` FROM `user_logins` WHERE `user_id` IN (?,?,?,?,?) AND `user_session_id` = ?" { + t.Fatal("invalid query") + } + } else { + if query != `SELECT "id","user_id","user_session_id","login_param_id","name","created_at","updated_at" FROM "user_logins" WHERE "user_id" IN ($1,$2,$3,$4,$5) AND "user_session_id" = $6` { + t.Fatal("invalid query") + } } }) t.Run("IS NULL query", func(t *testing.T) { NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("created_at", nil) queries, err := builder.BuildWithIndex(slc.valueFactory, slc.indexes, slc.typ) @@ -613,8 +650,15 @@ func testQueryBuilder(t *testing.T, typ CacheServerType) { return server.ErrCacheMiss })) query, _ := queries.CacheMissQueriesToSQL(slc.typ) - if query != "SELECT `id`,`user_id`,`user_session_id`,`login_param_id`,`name`,`created_at`,`updated_at` FROM `user_logins` WHERE `user_id` IN (?,?,?,?,?) AND `created_at` IS NULL" { - t.Fatal("invalid query") + + if driver.DBType == database.MySQL { + if query != "SELECT `id`,`user_id`,`user_session_id`,`login_param_id`,`name`,`created_at`,`updated_at` FROM `user_logins` WHERE `user_id` IN (?,?,?,?,?) AND `created_at` IS NULL" { + t.Fatal("invalid query") + } + } else { + if query != `SELECT "id","user_id","user_session_id","login_param_id","name","created_at","updated_at" FROM "user_logins" WHERE "user_id" IN ($1,$2,$3,$4,$5) AND "created_at" IS NULL` { + t.Fatal("invalid query") + } } }) } @@ -630,11 +674,11 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { t.Run("find by index column query builder", func(t *testing.T) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("login_param_id", uint64(1)) @@ -665,7 +709,7 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { t.Fatal("cannot work FindByQueryBuilder") } t.Run("duplicated values", func(t *testing.T) { - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}). Eq("login_param_id", uint64(1)) var userLogins UserLogins @@ -690,11 +734,11 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { t.Run("cache miss query, find from db", func(t *testing.T) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) txConn, err := conn.Begin() @@ -713,7 +757,7 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { t.Run("partially found pk and value in cache with uq key", func(t *testing.T) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) { txConn, err := conn.Begin() @@ -725,7 +769,7 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, tx.CommitCacheOnly()) } - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5, 6}). Eq("user_session_id", uint64(1)) txConn, err := conn.Begin() @@ -742,7 +786,7 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { t.Run("partially found pk and value in cache with idx key", func(t *testing.T) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) { txConn, err := conn.Begin() @@ -785,7 +829,7 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, tx.Commit()) } { - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(5)). In("login_param_id", []uint64{1, 2}) txConn, err := conn.Begin() @@ -811,7 +855,7 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, tx.CommitCacheOnly()) } - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(5)). In("login_param_id", []uint64{1, 2}) txConn, err := conn.Begin() @@ -829,7 +873,7 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { t.Run("find after updated index column value in same tx", func(t *testing.T) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) @@ -840,7 +884,7 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { var userLogin *UserLogin { - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) var userLogins UserLogins @@ -851,14 +895,14 @@ func testFindByQueryBuilder(t *testing.T, typ CacheServerType) { } userLogin = userLogins[0] } - updateBuilder := NewQueryBuilder("user_logins").Eq("id", userLogin.ID) + updateBuilder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", userLogin.ID) updateMap := map[string]interface{}{ "login_param_id": uint64(5), } NoError(t, slc.UpdateByQueryBuilder(context.Background(), tx, updateBuilder, updateMap)) { - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", userLogin.UserID). Eq("login_param_id", uint64(5)) var userLogins UserLogins @@ -892,16 +936,15 @@ func testUpdateByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) s := "user_id" - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{shardKey: &s}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{shardKey: &s}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) - fmt.Println("AAAA", slc.opt) t.Run("available cache", func(t *testing.T) { txConn, err := conn.Begin() NoError(t, err) tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) name := fmt.Sprintf("rapidash_%d", 2) @@ -917,7 +960,7 @@ func testUpdateByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, tx.Commit()) }) t.Run("unavailable cache", func(t *testing.T) { - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Gte("user_id", uint64(6)). Lte("user_id", uint64(10)) t.Run("update without cache", func(t *testing.T) { @@ -934,7 +977,7 @@ func testUpdateByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, slc.UpdateByQueryBuilder(context.Background(), tx, builder, updateParam)) var newUserLogins UserLogins - findBuilder := NewQueryBuilder("user_logins").In("user_id", []uint64{6, 7, 8, 9, 10}) + findBuilder := NewQueryBuilder("user_logins", driver.Adapter).In("user_id", []uint64{6, 7, 8, 9, 10}) NoError(t, slc.FindByQueryBuilder(context.Background(), tx, findBuilder, &newUserLogins)) Equal(t, len(newUserLogins), 5) @@ -959,7 +1002,7 @@ func testUpdateByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, slc.UpdateByQueryBuilder(context.Background(), tx, builder, updateParam)) var newUserLogins UserLogins - findBuilder := NewQueryBuilder("user_logins").In("user_id", []uint64{6, 7, 8, 9, 10}) + findBuilder := NewQueryBuilder("user_logins", driver.Adapter).In("user_id", []uint64{6, 7, 8, 9, 10}) NoError(t, slc.FindByQueryBuilder(context.Background(), tx, findBuilder, &newUserLogins)) Equal(t, len(newUserLogins), 5) for _, userLogin := range newUserLogins { @@ -982,11 +1025,11 @@ func TestUpdateUniqueKeyColumn(t *testing.T) { func testUpdateUniqueKeyColumn(t *testing.T, typ CacheServerType) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("login_param_id", uint64(1)) txConn, err := conn.Begin() @@ -998,7 +1041,7 @@ func testUpdateUniqueKeyColumn(t *testing.T, typ CacheServerType) { } NoError(t, slc.UpdateByQueryBuilder(context.Background(), tx, builder, updateParam)) { - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("login_param_id", uint64(10)) var newUserLogins UserLogins @@ -1019,10 +1062,10 @@ func TestUpdateKeyColumn(t *testing.T) { func testUpdateKeyColumn(t *testing.T, typ CacheServerType) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) txConn, err := conn.Begin() @@ -1034,7 +1077,7 @@ func testUpdateKeyColumn(t *testing.T, typ CacheServerType) { } NoError(t, slc.UpdateByQueryBuilder(context.Background(), tx, builder, updateParam)) { - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(10)) var newUserLogins UserLogins @@ -1049,7 +1092,7 @@ func testUpdateKeyColumn(t *testing.T, typ CacheServerType) { func TestUniqueIndexColumnUpdateByPrimaryKey(t *testing.T) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, CacheServerTypeMemcached)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) @@ -1060,7 +1103,7 @@ func TestUniqueIndexColumnUpdateByPrimaryKey(t *testing.T) { tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("user_session_id", uint64(1)) @@ -1071,7 +1114,7 @@ func TestUniqueIndexColumnUpdateByPrimaryKey(t *testing.T) { t.Fatal("failed to get value by index key") } - builder = NewQueryBuilder("user_logins"). + builder = NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("user_session_id", uint64(2)) @@ -1093,7 +1136,7 @@ func TestUniqueIndexColumnUpdateByPrimaryKey(t *testing.T) { updateParam := map[string]interface{}{ "user_session_id": uint64(2), } - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) NoError(t, slc.UpdateByQueryBuilder(context.Background(), tx, builder, updateParam)) NoError(t, tx.Commit()) } @@ -1104,7 +1147,7 @@ func TestUniqueIndexColumnUpdateByPrimaryKey(t *testing.T) { NoError(t, err) tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("user_session_id", uint64(1)) @@ -1122,7 +1165,7 @@ func TestUniqueIndexColumnUpdateByPrimaryKey(t *testing.T) { NoError(t, err) tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("user_session_id", uint64(2)) @@ -1138,7 +1181,7 @@ func TestUniqueIndexColumnUpdateByPrimaryKey(t *testing.T) { func TestIndexColumnUpdateByPrimaryKey(t *testing.T) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, CacheServerTypeMemcached)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) @@ -1149,7 +1192,7 @@ func TestIndexColumnUpdateByPrimaryKey(t *testing.T) { tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("login_param_id", uint64(1)) @@ -1160,7 +1203,7 @@ func TestIndexColumnUpdateByPrimaryKey(t *testing.T) { t.Fatal("failed to get value by index key") } - builder = NewQueryBuilder("user_logins"). + builder = NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("login_param_id", uint64(2)) @@ -1182,7 +1225,7 @@ func TestIndexColumnUpdateByPrimaryKey(t *testing.T) { updateParam := map[string]interface{}{ "login_param_id": uint64(2), } - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) NoError(t, slc.UpdateByQueryBuilder(context.Background(), tx, builder, updateParam)) NoError(t, tx.Commit()) } @@ -1193,7 +1236,7 @@ func TestIndexColumnUpdateByPrimaryKey(t *testing.T) { NoError(t, err) tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("login_param_id", uint64(1)) @@ -1211,7 +1254,7 @@ func TestIndexColumnUpdateByPrimaryKey(t *testing.T) { NoError(t, err) tx, err := cache.Begin(txConn) NoError(t, err) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("login_param_id", uint64(2)) @@ -1227,7 +1270,7 @@ func TestIndexColumnUpdateByPrimaryKey(t *testing.T) { func TestLockingRead(t *testing.T) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, CacheServerTypeMemcached)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.cacheServer.Flush()) NoError(t, slc.WarmUp(conn)) @@ -1238,7 +1281,7 @@ func TestLockingRead(t *testing.T) { // store cache to stash { - builder := NewQueryBuilder("user_logins").Eq("user_id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("user_id", uint64(1)) var userLogin UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &userLogin)) @@ -1259,7 +1302,7 @@ func TestLockingRead(t *testing.T) { updateParam := map[string]interface{}{ "login_param_id": uint64(2), } - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) NoError(t, slc.UpdateByQueryBuilder(context.Background(), tx, builder, updateParam)) NoError(t, tx.Commit()) } @@ -1268,7 +1311,7 @@ func TestLockingRead(t *testing.T) { // in this case, cannot get updated value in normal query, // but if use locking read query, could read updated value. { - builder := NewQueryBuilder("user_logins").Eq("user_id", uint64(1)).ForUpdate() + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("user_id", uint64(1)).ForUpdate() var userLogin UserLogin NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &userLogin)) @@ -1291,11 +1334,11 @@ func TestDeleteByQueryBuilder(t *testing.T) { func testDeleteByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) t.Run("cache is available", func(t *testing.T) { NoError(t, initUserLoginTable(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) txConn, err := conn.Begin() @@ -1308,7 +1351,7 @@ func testDeleteByQueryBuilder(t *testing.T, typ CacheServerType) { t.Run("not available cache", func(t *testing.T) { NoError(t, initUserLoginTable(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Gte("user_session_id", uint64(1)). Lte("user_session_id", uint64(3)) txConn, err := conn.Begin() @@ -1318,7 +1361,7 @@ func testDeleteByQueryBuilder(t *testing.T, typ CacheServerType) { NoError(t, slc.DeleteByQueryBuilder(context.Background(), tx, builder)) var userLogins UserLogins - findBuilder := NewQueryBuilder("user_logins"). + findBuilder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). In("user_session_id", []uint64{1, 2, 3}) NoError(t, slc.FindByQueryBuilder(context.Background(), tx, findBuilder, &userLogins)) @@ -1330,7 +1373,7 @@ func testDeleteByQueryBuilder(t *testing.T, typ CacheServerType) { t.Run("delete by primary keys", func(t *testing.T) { NoError(t, initUserLoginTable(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("id", []uint64{1, 2, 3, 4, 5}) txConn, err := conn.Begin() NoError(t, err) @@ -1357,7 +1400,7 @@ func TestRawQuery(t *testing.T) { func testRawQuery(t *testing.T, typ CacheServerType) { NoError(t, initUserLoginTable(conn)) NoError(t, initCache(conn, typ)) - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) txConn, err := conn.Begin() @@ -1367,8 +1410,9 @@ func testRawQuery(t *testing.T, typ CacheServerType) { defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() t.Run("raw query", func(t *testing.T) { - builder := NewQueryBuilder("user_logins"). - SQL("ORDER BY id DESC LIMIT ? OFFSET ?", 3, 1) + qh := driver.Adapter.QueryHelper() + builder := NewQueryBuilder("user_logins", driver.Adapter). + SQL(fmt.Sprintf("ORDER BY id DESC LIMIT %s OFFSET %s", qh.Placeholder(), qh.Placeholder()), 3, 1) var userLogins UserLogins NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &userLogins)) if len(userLogins) != 3 && @@ -1379,7 +1423,7 @@ func testRawQuery(t *testing.T, typ CacheServerType) { } }) t.Run("all query", func(t *testing.T) { - builder := NewQueryBuilder("user_logins") + builder := NewQueryBuilder("user_logins", driver.Adapter) var userLogins UserLogins NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &userLogins)) if len(userLogins) != 1000 { @@ -1409,7 +1453,9 @@ type PtrType struct { } func (p *PtrType) EncodeRapidash(enc Encoder) error { - enc.Uint64("id", p.id) + if p.id != 0 { + enc.Uint64("id", p.id) + } enc.IntPtr("intptr", p.intPtr) enc.Int8Ptr("int8ptr", p.int8Ptr) enc.Int16Ptr("int16ptr", p.int16Ptr) @@ -1575,7 +1621,7 @@ func validateNotNilValue(t *testing.T, v *PtrType) { func TestPointerType(t *testing.T) { NoError(t, initCache(conn, CacheServerTypeMemcached)) - slc := NewSecondLevelCache(new(PtrType).Type(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(new(PtrType).Type(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) t.Run("invalid value", func(t *testing.T) { @@ -1585,7 +1631,7 @@ func TestPointerType(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.Rollback()) }() - builder := NewQueryBuilder("ptr").Eq("id", uint64(1)) + builder := NewQueryBuilder("ptr", driver.Adapter).Eq("id", uint64(1)) var v PtrType NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) if v.id != 1 { @@ -1600,7 +1646,7 @@ func TestPointerType(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.Rollback()) }() - builder := NewQueryBuilder("ptr").Eq("id", uint64(2)) + builder := NewQueryBuilder("ptr", driver.Adapter).Eq("id", uint64(2)) var v PtrType NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &v)) if v.id != 2 { @@ -1621,7 +1667,7 @@ func TestPointerType(t *testing.T) { t.Fatal("cannot insert invalid value") } var foundValue PtrType - NoError(t, slc.FindByQueryBuilder(context.Background(), tx, NewQueryBuilder("ptr").Eq("id", uint64(id)), &foundValue)) + NoError(t, slc.FindByQueryBuilder(context.Background(), tx, NewQueryBuilder("ptr", driver.Adapter).Eq("id", uint64(id)), &foundValue)) // set invalid value to cache server NoError(t, tx.Commit()) @@ -1633,7 +1679,7 @@ func TestPointerType(t *testing.T) { defer func() { NoError(t, tx.Rollback()) }() var foundValue PtrType - NoError(t, slc.FindByQueryBuilder(context.Background(), tx, NewQueryBuilder("ptr").Eq("id", uint64(id)), &foundValue)) + NoError(t, slc.FindByQueryBuilder(context.Background(), tx, NewQueryBuilder("ptr", driver.Adapter).Eq("id", uint64(id)), &foundValue)) }) }) t.Run("update valid value", func(t *testing.T) { @@ -1662,12 +1708,12 @@ func TestPointerType(t *testing.T) { defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() var foundValue PtrType - NoError(t, slc.FindByQueryBuilder(context.Background(), tx, NewQueryBuilder("ptr").Eq("id", uint64(1)), &foundValue)) + NoError(t, slc.FindByQueryBuilder(context.Background(), tx, NewQueryBuilder("ptr", driver.Adapter).Eq("id", uint64(1)), &foundValue)) if foundValue.id == 0 { t.Fatal("cannot find value") } - builder := NewQueryBuilder("ptr").Eq("id", foundValue.id) + builder := NewQueryBuilder("ptr", driver.Adapter).Eq("id", foundValue.id) t.Run("not pointer value map", func(t *testing.T) { updateMap := map[string]interface{}{ "intptr": intValue, @@ -1715,6 +1761,7 @@ func TestPointerType(t *testing.T) { }) t.Run("some queries", func(t *testing.T) { + columns := []string{ "intptr", "int8ptr", @@ -1733,44 +1780,51 @@ func TestPointerType(t *testing.T) { "stringptr", "timeptr", } - txConn, err := conn.Begin() - NoError(t, err) + { + txConn, err := conn.Begin() + NoError(t, err) - for idx, column := range columns { - if _, err := txConn.Exec(fmt.Sprintf("ALTER TABLE `ptr` ADD INDEX idx_%d(%s)", idx+1, column)); err != nil { - t.Fatalf("%+v", err) + alterfmt := map[database.DBType]string{ + database.MySQL: "ALTER TABLE ptr ADD INDEX idx_%d(%s)", + database.Postgres: "CREATE INDEX idx_%d ON ptr (%s)", } + for idx, column := range columns { + if _, err := txConn.Exec(fmt.Sprintf(alterfmt[driver.DBType], idx+1, column)); err != nil { + t.Fatalf("%+v", err) + } + } + NoError(t, txConn.Commit()) } - fmt.Println("ALTER END") + txConn, err := conn.Begin() + NoError(t, err) NoError(t, slc.WarmUp(conn)) fmt.Println("WARM UP END") tx, err := cache.Begin(txConn) fmt.Println("BEGIN END") NoError(t, err) - defer func() { NoError(t, tx.Rollback()) }() var ptr PtrType - builder := NewQueryBuilder("ptr").Eq("id", uint64(2)) + builder := NewQueryBuilder("ptr", driver.Adapter).Eq("id", uint64(2)) NoError(t, slc.FindByQueryBuilder(context.Background(), tx, builder, &ptr)) t.Run("pointer value query", func(t *testing.T) { builders := []*QueryBuilder{ - NewQueryBuilder("ptr").Eq("id", &ptr.id), - NewQueryBuilder("ptr").Eq("intptr", ptr.intPtr), - NewQueryBuilder("ptr").Eq("int8ptr", ptr.int8Ptr), - NewQueryBuilder("ptr").Eq("int16ptr", ptr.int16Ptr), - NewQueryBuilder("ptr").Eq("int32ptr", ptr.int32Ptr), - NewQueryBuilder("ptr").Eq("int64ptr", ptr.int64Ptr), - NewQueryBuilder("ptr").Eq("uintptr", ptr.uintPtr), - NewQueryBuilder("ptr").Eq("uint8ptr", ptr.uint8Ptr), - NewQueryBuilder("ptr").Eq("uint16ptr", ptr.uint16Ptr), - NewQueryBuilder("ptr").Eq("uint32ptr", ptr.uint32Ptr), - NewQueryBuilder("ptr").Eq("uint64ptr", ptr.uint64Ptr), - NewQueryBuilder("ptr").Eq("float32ptr", ptr.float32Ptr), - NewQueryBuilder("ptr").Eq("float64ptr", ptr.float64Ptr), - NewQueryBuilder("ptr").Eq("boolptr", ptr.boolPtr), - NewQueryBuilder("ptr").Eq("bytesptr", ptr.bytesPtr), - NewQueryBuilder("ptr").Eq("stringptr", ptr.stringPtr), - NewQueryBuilder("ptr").Eq("timeptr", ptr.timePtr), + NewQueryBuilder("ptr", driver.Adapter).Eq("id", &ptr.id), + NewQueryBuilder("ptr", driver.Adapter).Eq("intptr", ptr.intPtr), + NewQueryBuilder("ptr", driver.Adapter).Eq("int8ptr", ptr.int8Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("int16ptr", ptr.int16Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("int32ptr", ptr.int32Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("int64ptr", ptr.int64Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("uintptr", ptr.uintPtr), + NewQueryBuilder("ptr", driver.Adapter).Eq("uint8ptr", ptr.uint8Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("uint16ptr", ptr.uint16Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("uint32ptr", ptr.uint32Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("uint64ptr", ptr.uint64Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("float32ptr", ptr.float32Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("float64ptr", ptr.float64Ptr), + NewQueryBuilder("ptr", driver.Adapter).Eq("boolptr", ptr.boolPtr), + NewQueryBuilder("ptr", driver.Adapter).Eq("bytesptr", ptr.bytesPtr), + NewQueryBuilder("ptr", driver.Adapter).Eq("stringptr", ptr.stringPtr), + NewQueryBuilder("ptr", driver.Adapter).Eq("timeptr", ptr.timePtr), } for _, builder := range builders { var v PtrType @@ -1780,22 +1834,22 @@ func TestPointerType(t *testing.T) { }) t.Run("IN condition query", func(t *testing.T) { builders := []*QueryBuilder{ - NewQueryBuilder("ptr").In("intptr", []int{1}), - NewQueryBuilder("ptr").In("int8ptr", []int8{2}), - NewQueryBuilder("ptr").In("int16ptr", []int16{3}), - NewQueryBuilder("ptr").In("int32ptr", []int32{4}), - NewQueryBuilder("ptr").In("int64ptr", []int64{5}), - NewQueryBuilder("ptr").In("uintptr", []uint{6}), - NewQueryBuilder("ptr").In("uint8ptr", []uint8{7}), - NewQueryBuilder("ptr").In("uint16ptr", []uint16{8}), - NewQueryBuilder("ptr").In("uint32ptr", []uint32{9}), - NewQueryBuilder("ptr").In("uint64ptr", []uint64{10}), - NewQueryBuilder("ptr").In("float32ptr", []float32{1.23}), - NewQueryBuilder("ptr").In("float64ptr", []float64{4.56}), - NewQueryBuilder("ptr").In("boolptr", []bool{true}), - NewQueryBuilder("ptr").In("bytesptr", [][]byte{[]byte("bytes")}), - NewQueryBuilder("ptr").In("stringptr", []string{"string"}), - NewQueryBuilder("ptr").In("timeptr", []time.Time{*ptr.timePtr}), + NewQueryBuilder("ptr", driver.Adapter).In("intptr", []int{1}), + NewQueryBuilder("ptr", driver.Adapter).In("int8ptr", []int8{2}), + NewQueryBuilder("ptr", driver.Adapter).In("int16ptr", []int16{3}), + NewQueryBuilder("ptr", driver.Adapter).In("int32ptr", []int32{4}), + NewQueryBuilder("ptr", driver.Adapter).In("int64ptr", []int64{5}), + NewQueryBuilder("ptr", driver.Adapter).In("uintptr", []uint{6}), + NewQueryBuilder("ptr", driver.Adapter).In("uint8ptr", []uint8{7}), + NewQueryBuilder("ptr", driver.Adapter).In("uint16ptr", []uint16{8}), + NewQueryBuilder("ptr", driver.Adapter).In("uint32ptr", []uint32{9}), + NewQueryBuilder("ptr", driver.Adapter).In("uint64ptr", []uint64{10}), + NewQueryBuilder("ptr", driver.Adapter).In("float32ptr", []float32{1.23}), + NewQueryBuilder("ptr", driver.Adapter).In("float64ptr", []float64{4.56}), + NewQueryBuilder("ptr", driver.Adapter).In("boolptr", []bool{true}), + NewQueryBuilder("ptr", driver.Adapter).In("bytesptr", [][]byte{[]byte("bytes")}), + NewQueryBuilder("ptr", driver.Adapter).In("stringptr", []string{"string"}), + NewQueryBuilder("ptr", driver.Adapter).In("timeptr", []time.Time{*ptr.timePtr}), } for _, builder := range builders { var v PtrType @@ -1803,10 +1857,21 @@ func TestPointerType(t *testing.T) { NotEqualf(t, v.id, uint64(0), "cannot find by IN query") } }) - for idx := range columns { - if _, err := txConn.Exec(fmt.Sprintf("ALTER TABLE `ptr` DROP INDEX idx_%d", idx+1)); err != nil { - t.Fatalf("%+v", err) + NoError(t, tx.Rollback()) + { + txConn, err := conn.Begin() + NoError(t, err) + + alterfmt := map[database.DBType]string{ + database.MySQL: "ALTER TABLE `ptr` DROP INDEX idx_%d", + database.Postgres: "DROP INDEX idx_%d", + } + for idx := range columns { + if _, err := txConn.Exec(fmt.Sprintf(alterfmt[driver.DBType], idx+1)); err != nil { + t.Fatalf("%+v", err) + } } + NoError(t, txConn.Commit()) } }) } @@ -1838,11 +1903,11 @@ func BenchmarkSLCIN_SimpleMemcachedAccess(b *testing.B) { panic(err) } setNopLogger() - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) if err := slc.WarmUp(conn); err != nil { panic(err) } - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) tx, err := cache.Begin(conn) @@ -1884,11 +1949,11 @@ func BenchmarkSLCIN_SimpleRedisAccess(b *testing.B) { panic(err) } setNopLogger() - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) if err := slc.WarmUp(conn); err != nil { panic(err) } - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) tx, err := cache.Begin(conn) @@ -1986,12 +2051,12 @@ func benchmarkSLCINRapidash(b *testing.B) { panic(err) } setNopLogger() - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) if err := slc.WarmUp(conn); err != nil { panic(err) } b.ResetTimer() - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). In("user_id", []uint64{1, 2, 3, 4, 5}). Eq("user_session_id", uint64(1)) userLogins := []*UserLogin{} @@ -2015,9 +2080,9 @@ func benchmarkSLCINRapidash(b *testing.B) { } func TestCountQuerySLC(t *testing.T) { - slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(userLoginType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) - builder := NewQueryBuilder("user_logins"). + builder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)) tx, err := cache.Begin(conn) if err != nil { @@ -2031,9 +2096,9 @@ func TestCountQuerySLC(t *testing.T) { } func TestCountByQueryBuilderCaseDatabaseRecordIsEmptySLC(t *testing.T) { - slc := NewSecondLevelCache(emptyType(), cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(emptyType(), cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) - builder := NewQueryBuilder("empties").Eq("id", uint64(1)) + builder := NewQueryBuilder("empties", driver.Adapter).Eq("id", uint64(1)) tx, err := cache.Begin(conn) if err != nil { panic(err) @@ -2046,19 +2111,12 @@ func TestCountByQueryBuilderCaseDatabaseRecordIsEmptySLC(t *testing.T) { } func TestWarmUp(t *testing.T) { - _, err := conn.Exec("DROP TABLE IF EXISTS warm_up_users") + NoError(t, initTable(conn, "warm_up_users")) + f, err := os.Open(filepath.Join("testdata", driver.Name, "alter_warm_up_users.sql")) NoError(t, err) + defer f.Close() + queryScanner := bufio.NewScanner(f) - sql := ` - CREATE TABLE IF NOT EXISTS warm_up_users ( - id bigint(20) unsigned NOT NULL AUTO_INCREMENT, - user_id bigint(20) unsigned NOT NULL, - nickname varchar(255) NOT NULL, - age int(10) NOT NULL, - created_at datetime NOT NULL, - PRIMARY KEY (id) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8 -` strc := NewStruct("warm_up_users"). FieldUint64("id"). FieldUint64("user_id"). @@ -2066,11 +2124,8 @@ func TestWarmUp(t *testing.T) { FieldUint64("age"). FieldUint64("created_at") - _, err = conn.Exec(sql) - NoError(t, err) - t.Run("only a single pk", func(t *testing.T) { - slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) Equal(t, len(slc.indexes), 1) @@ -2082,7 +2137,7 @@ func TestWarmUp(t *testing.T) { t.Run("with shard_key", func(t *testing.T) { shardKey := "user_id" - slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{shardKey: &shardKey}) + slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{shardKey: &shardKey}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) Equal(t, len(slc.indexes), 1) @@ -2096,9 +2151,10 @@ func TestWarmUp(t *testing.T) { }) t.Run("pk multiple pair", func(t *testing.T) { - _, err := conn.Exec("ALTER TABLE warm_up_users DROP PRIMARY KEY, ADD PRIMARY KEY (id, created_at)") + queryScanner.Scan() + _, err := conn.Exec(queryScanner.Text()) NoError(t, err) - slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) Equal(t, len(slc.indexes), 2) { @@ -2119,7 +2175,7 @@ func TestWarmUp(t *testing.T) { t.Run("with shard_key", func(t *testing.T) { shardKey := "user_id" - slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{shardKey: &shardKey}) + slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{shardKey: &shardKey}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) Equal(t, len(slc.indexes), 2) @@ -2144,9 +2200,10 @@ func TestWarmUp(t *testing.T) { }) t.Run("index key", func(t *testing.T) { - _, err := conn.Exec("ALTER TABLE warm_up_users DROP PRIMARY KEY, ADD PRIMARY KEY (id), ADD INDEX idx_user_id_nickname(user_id, nickname)") + queryScanner.Scan() + _, err := conn.Exec(queryScanner.Text()) NoError(t, err) - slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) Equal(t, len(slc.indexes), 3) { @@ -2174,9 +2231,10 @@ func TestWarmUp(t *testing.T) { }) t.Run("unique key", func(t *testing.T) { - _, err := conn.Exec("ALTER TABLE warm_up_users DROP INDEX idx_user_id_nickname, ADD UNIQUE uq_user_id_nickname(user_id, nickname)") + queryScanner.Scan() + _, err := conn.Exec(queryScanner.Text()) NoError(t, err) - slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{}) + slc := NewSecondLevelCache(strc, cache.cacheServer, TableOption{}, database.NewAdapterWithDBType(driver.DBType)) NoError(t, slc.WarmUp(conn)) Equal(t, len(slc.indexes), 3) { diff --git a/testdata/mysql/alter_user_logins.sql b/testdata/mysql/alter_user_logins.sql new file mode 100644 index 0000000..c5f7e01 --- /dev/null +++ b/testdata/mysql/alter_user_logins.sql @@ -0,0 +1,3 @@ +ALTER TABLE user_logins ADD password varchar(10) DEFAULT '100'; +ALTER TABLE user_logins MODIFY COLUMN password int(20) unsigned; +ALTER TABLE user_logins DROP COLUMN password; diff --git a/testdata/mysql/alter_warm_up_users.sql b/testdata/mysql/alter_warm_up_users.sql new file mode 100644 index 0000000..a170270 --- /dev/null +++ b/testdata/mysql/alter_warm_up_users.sql @@ -0,0 +1,4 @@ +ALTER TABLE warm_up_users DROP PRIMARY KEY, ADD PRIMARY KEY (id, created_at); +ALTER TABLE warm_up_users DROP PRIMARY KEY, ADD PRIMARY KEY (id), ADD INDEX idx_user_id_nickname(user_id, nickname); +ALTER TABLE warm_up_users DROP INDEX idx_user_id_nickname, ADD UNIQUE uq_user_id_nickname(user_id, nickname); + diff --git a/testdata/mysql/empties.sql b/testdata/mysql/empties.sql new file mode 100644 index 0000000..f958c09 --- /dev/null +++ b/testdata/mysql/empties.sql @@ -0,0 +1,6 @@ +DROP TABLE IF EXISTS empties; + +CREATE TABLE IF NOT EXISTS empties ( + id bigint(20) unsigned NOT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/testdata/mysql/events.sql b/testdata/mysql/events.sql new file mode 100644 index 0000000..e7b9fd2 --- /dev/null +++ b/testdata/mysql/events.sql @@ -0,0 +1,15 @@ +DROP TABLE IF EXISTS events; + +CREATE TABLE IF NOT EXISTS events ( + id bigint(20) unsigned NOT NULL, + event_id bigint(20) unsigned NOT NULL, + event_category_id bigint(20) unsigned NOT NULL, + term enum('early_morning', 'morning', 'daytime', 'evening', 'night', 'midnight') NOT NULL, + start_week int(10) unsigned NOT NULL, + end_week int(10) unsigned NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY (event_id, start_week), + KEY (term, start_week, end_week) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/testdata/mysql/ptr.sql b/testdata/mysql/ptr.sql new file mode 100644 index 0000000..135ec00 --- /dev/null +++ b/testdata/mysql/ptr.sql @@ -0,0 +1,22 @@ +DROP TABLE IF EXISTS ptr; + +CREATE TABLE IF NOT EXISTS ptr ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + intptr int, + int8ptr int, + int16ptr int, + int32ptr int, + int64ptr int, + uintptr int unsigned, + uint8ptr int unsigned, + uint16ptr int unsigned, + uint32ptr int unsigned, + uint64ptr bigint unsigned, + float32ptr float, + float64ptr double, + boolptr tinyint, + bytesptr varchar(255), + stringptr varchar(255), + timeptr datetime, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/testdata/mysql/user_logins.sql b/testdata/mysql/user_logins.sql new file mode 100644 index 0000000..3b68b4e --- /dev/null +++ b/testdata/mysql/user_logins.sql @@ -0,0 +1,15 @@ +DROP TABLE IF EXISTS user_logins; + +CREATE TABLE IF NOT EXISTS user_logins ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + user_id bigint(20) unsigned NOT NULL, + user_session_id bigint(20) unsigned NOT NULL, + login_param_id bigint(20) unsigned NOT NULL, + name varchar(255) NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id), + UNIQUE KEY (user_id, user_session_id), + KEY (user_id, login_param_id), + KEY (user_id, created_at) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/testdata/mysql/user_logs.sql b/testdata/mysql/user_logs.sql new file mode 100644 index 0000000..7fe3b35 --- /dev/null +++ b/testdata/mysql/user_logs.sql @@ -0,0 +1,13 @@ +DROP TABLE IF EXISTS user_logs; + +CREATE TABLE IF NOT EXISTS user_logs ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + user_id bigint(20) unsigned NOT NULL, + content_type varchar(255) NOT NULL, + content_id bigint(20) unsigned NOT NULL, + created_at datetime NOT NULL, + updated_at datetime NOT NULL, + PRIMARY KEY (id), + KEY (user_id, created_at), + KEY (user_id, content_type, content_id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/testdata/mysql/warm_up_users.sql b/testdata/mysql/warm_up_users.sql new file mode 100644 index 0000000..31cce46 --- /dev/null +++ b/testdata/mysql/warm_up_users.sql @@ -0,0 +1,9 @@ +DROP TABLE IF EXISTS warm_up_users; +CREATE TABLE IF NOT EXISTS warm_up_users ( + id bigint(20) unsigned NOT NULL AUTO_INCREMENT, + user_id bigint(20) unsigned NOT NULL, + nickname varchar(255) NOT NULL, + age int(10) NOT NULL, + created_at datetime NOT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; diff --git a/testdata/postgres/alter_user_logins.sql b/testdata/postgres/alter_user_logins.sql new file mode 100644 index 0000000..0ae83de --- /dev/null +++ b/testdata/postgres/alter_user_logins.sql @@ -0,0 +1,3 @@ +ALTER TABLE user_logins ADD password varchar(10) DEFAULT '100'; +ALTER TABLE user_logins ALTER COLUMN password DROP DEFAULT, ALTER COLUMN password TYPE integer USING (password::integer); +ALTER TABLE user_logins DROP COLUMN password; diff --git a/testdata/postgres/alter_warm_up_users.sql b/testdata/postgres/alter_warm_up_users.sql new file mode 100644 index 0000000..e5802c1 --- /dev/null +++ b/testdata/postgres/alter_warm_up_users.sql @@ -0,0 +1,3 @@ +ALTER TABLE warm_up_users DROP CONSTRAINT warm_up_users_pkey, ADD PRIMARY KEY (id, created_at); +ALTER TABLE warm_up_users DROP CONSTRAINT warm_up_users_pkey, ADD PRIMARY KEY (id); CREATE INDEX idx_user_id_nickname ON warm_up_users (user_id, nickname); +ALTER TABLE warm_up_users ADD CONSTRAINT uq_user_id_nickname UNIQUE (user_id, nickname); DROP INDEX idx_user_id_nickname; diff --git a/testdata/postgres/empties.sql b/testdata/postgres/empties.sql new file mode 100644 index 0000000..96ce783 --- /dev/null +++ b/testdata/postgres/empties.sql @@ -0,0 +1,6 @@ +DROP TABLE IF EXISTS empties; + +CREATE TABLE IF NOT EXISTS empties ( + id bigint NOT NULL, + PRIMARY KEY (id) +); diff --git a/testdata/postgres/events.sql b/testdata/postgres/events.sql new file mode 100644 index 0000000..9ccbe93 --- /dev/null +++ b/testdata/postgres/events.sql @@ -0,0 +1,19 @@ +DROP TABLE IF EXISTS events; + +DROP TYPE IF EXISTS term; +CREATE TYPE term AS ENUM('early_morning', 'morning', 'daytime', 'evening', 'night', 'midnight'); + +CREATE TABLE IF NOT EXISTS events ( + id bigint NOT NULL, + event_id bigint NOT NULL, + event_category_id bigint NOT NULL, + term term NOT NULL, + start_week integer NOT NULL, + end_week integer NOT NULL, + created_at timestamp with time zone NOT NULL, + updated_at timestamp with time zone NOT NULL, + PRIMARY KEY (id), + UNIQUE (event_id, start_week) +); + +CREATE INDEX ON events (term, start_week, end_week); diff --git a/testdata/postgres/ptr.sql b/testdata/postgres/ptr.sql new file mode 100644 index 0000000..d5300e9 --- /dev/null +++ b/testdata/postgres/ptr.sql @@ -0,0 +1,22 @@ +DROP TABLE IF EXISTS ptr; + +CREATE TABLE IF NOT EXISTS ptr ( + id SERIAL, + intptr integer, + int8ptr integer, + int16ptr integer, + int32ptr integer, + int64ptr integer, + uintptr integer, + uint8ptr integer, + uint16ptr integer, + uint32ptr integer, + uint64ptr bigint, + float32ptr REAL, + float64ptr DOUBLE PRECISION, + boolptr BOOLEAN, + bytesptr varchar(255), + stringptr varchar(255), + timeptr timestamp with time zone, + PRIMARY KEY (id) +); diff --git a/testdata/postgres/user_logins.sql b/testdata/postgres/user_logins.sql new file mode 100644 index 0000000..83f140c --- /dev/null +++ b/testdata/postgres/user_logins.sql @@ -0,0 +1,16 @@ +DROP TABLE IF EXISTS user_logins; + +CREATE TABLE IF NOT EXISTS user_logins ( + id SERIAL, + user_id bigint NOT NULL, + user_session_id bigint NOT NULL, + login_param_id bigint NOT NULL, + name varchar(255) NOT NULL, + created_at timestamp with time zone NOT NULL, + updated_at timestamp with time zone NOT NULL, + PRIMARY KEY (id), + UNIQUE (user_id, user_session_id) +); + +CREATE INDEX ON user_logins (user_id, login_param_id); +CREATE INDEX ON user_logins (user_id, created_at); diff --git a/testdata/postgres/user_logs.sql b/testdata/postgres/user_logs.sql new file mode 100644 index 0000000..938cc6b --- /dev/null +++ b/testdata/postgres/user_logs.sql @@ -0,0 +1,14 @@ +DROP TABLE IF EXISTS user_logs; + +CREATE TABLE IF NOT EXISTS user_logs ( + id SERIAL, + user_id bigint NOT NULL, + content_type varchar(255) NOT NULL, + content_id bigint NOT NULL, + created_at timestamp with time zone NOT NULL, + updated_at timestamp with time zone NOT NULL, + PRIMARY KEY (id) +); + +CREATE INDEX ON user_logs (user_id, created_at); +CREATE INDEX ON user_logs (user_id, content_type, content_id); diff --git a/testdata/postgres/warm_up_users.sql b/testdata/postgres/warm_up_users.sql new file mode 100644 index 0000000..a068864 --- /dev/null +++ b/testdata/postgres/warm_up_users.sql @@ -0,0 +1,9 @@ +DROP TABLE IF EXISTS warm_up_users; +CREATE TABLE IF NOT EXISTS warm_up_users ( + id SERIAL NOT NULL, + user_id bigint NOT NULL, + nickname varchar(255) NOT NULL, + age integer NOT NULL, + created_at timestamp with time zone NOT NULL, + PRIMARY KEY (id) +); diff --git a/tx_test.go b/tx_test.go index 7f6187f..ed3db4c 100644 --- a/tx_test.go +++ b/tx_test.go @@ -4,10 +4,17 @@ import ( "context" "database/sql" "encoding/json" + "fmt" + "io/ioutil" "os" + "path/filepath" + "strings" "testing" "time" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + "go.knocknote.io/rapidash/database" "golang.org/x/xerrors" ) @@ -16,6 +23,28 @@ var ( cache *Rapidash ) +var ( + drivers = map[string]struct { + Name, Source string + DBType database.DBType + Adapter database.Adapter + }{ + "mysql": { + Name: "mysql", + Source: "root:@tcp(localhost:3306)/rapidash?parseTime=true", + DBType: database.MySQL, + Adapter: database.NewAdapterWithDBType(database.MySQL), + }, + "postgres": { + Name: "postgres", + Source: "host=localhost user=root dbname=rapidash sslmode=disable", + DBType: database.Postgres, + Adapter: database.NewAdapterWithDBType(database.Postgres), + }, + } + driver = drivers[os.Getenv("RAPIDASH_DB_DRIVER")] +) + func setUp(conn *sql.DB) error { if err := initDB(); err != nil { return xerrors.Errorf("failed to initDB: %w", err) @@ -52,38 +81,35 @@ func initDB() error { return nil } -func initEventTable(conn *sql.DB) error { - if _, err := conn.Exec("DROP TABLE IF EXISTS events"); err != nil { - return xerrors.Errorf("failed to drop events table: %w", err) +func initTable(conn *sql.DB, tableName string) error { + sql, err := ioutil.ReadFile(filepath.Join("testdata", driver.Name, tableName+".sql")) + if err != nil { + return xerrors.Errorf("failed to read sql file: %w", err) } + queries := strings.Split(string(sql), ";") + for _, query := range queries[:len(queries)-1] { + if _, err := conn.Exec(query); err != nil { + return xerrors.Errorf("failed to exec query: %w", err) + } + } + return nil +} - sql := ` -CREATE TABLE events ( - id bigint(20) unsigned NOT NULL, - event_id bigint(20) unsigned NOT NULL, - event_category_id bigint(20) unsigned NOT NULL, - term enum('early_morning', 'morning', 'daytime', 'evening', 'night', 'midnight') NOT NULL, - start_week int(10) unsigned NOT NULL, - end_week int(10) unsigned NOT NULL, - created_at datetime NOT NULL, - updated_at datetime NOT NULL, - PRIMARY KEY (id), - UNIQUE KEY (event_id, start_week), - KEY (term, start_week, end_week) -) ENGINE=InnoDB DEFAULT CHARSET=utf8; -` - - if _, err := conn.Exec(sql); err != nil { - return xerrors.Errorf("failed to create events table: %w", err) +func initEventTable(conn *sql.DB) error { + if err := initTable(conn, "events"); err != nil { + return xerrors.Errorf("failed to init events: %w", err) } id := 1 + adapter := driver.Adapter for eventID := 1; eventID <= 1000; eventID++ { startWeek := 1 endWeek := 12 term := "daytime" eventCategoryID := eventID for j := 0; j < 4; j++ { - if _, err := conn.Exec("insert into events values(?, ?, ?, ?, ?, ?, ?, ?)", id, eventID, eventCategoryID, term, startWeek, endWeek, time.Now(), time.Now()); err != nil { + qh := adapter.QueryHelper() + query := fmt.Sprintf("INSERT INTO %s values(%s)", qh.Quote("events"), qh.Placeholders(8)) + if _, err := conn.Exec(query, id, eventID, eventCategoryID, term, startWeek, endWeek, time.Now(), time.Now()); err != nil { return xerrors.Errorf("failed to insert into events table: %w", err) } id++ @@ -91,40 +117,23 @@ CREATE TABLE events ( endWeek += 12 } } - return nil } func initUserLoginTable(conn *sql.DB) error { - if _, err := conn.Exec("DROP TABLE IF EXISTS user_logins"); err != nil { - return xerrors.Errorf("failed to drop user_logins table: %w", err) + if err := initTable(conn, "user_logins"); err != nil { + return xerrors.Errorf("failed to exec user_logins: %w", err) } - sql := ` -CREATE TABLE IF NOT EXISTS user_logins ( - id bigint(20) unsigned NOT NULL AUTO_INCREMENT, - user_id bigint(20) unsigned NOT NULL, - user_session_id bigint(20) unsigned NOT NULL, - login_param_id bigint(20) unsigned NOT NULL, - name varchar(255) NOT NULL, - created_at datetime NOT NULL, - updated_at datetime NOT NULL, - PRIMARY KEY (id), - UNIQUE KEY (user_id, user_session_id), - KEY (user_id, login_param_id), - KEY (user_id, created_at) -) ENGINE=InnoDB DEFAULT CHARSET=utf8 -` - if _, err := conn.Exec(sql); err != nil { - return xerrors.Errorf("failed to create user_logins table: %w", err) - } - userID := 1 userSessionID := 1 loginParamID := 1 name := "rapidash1" + qh := driver.Adapter.QueryHelper() + columns := []string{qh.Quote("user_id"), qh.Quote("user_session_id"), qh.Quote("login_param_id"), qh.Quote("name"), qh.Quote("created_at"), qh.Quote("updated_at")} for ; userID <= 1000; userID++ { - if _, err := conn.Exec("INSERT INTO `user_logins` (`user_id`,`user_session_id`,`login_param_id`,`name`,`created_at`,`updated_at`) VALUES (?, ?, ?, ?, ?, ?)", - userID, userSessionID, loginParamID, name, time.Now(), time.Now()); err != nil { + qh := driver.Adapter.QueryHelper() + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qh.Quote("user_logins"), strings.Join(columns, ","), qh.Placeholders(6)) + if _, err := conn.Exec(query, userID, userSessionID, loginParamID, name, time.Now(), time.Now()); err != nil { return xerrors.Errorf("failed to insert into user_logins table: %w", err) } } @@ -132,39 +141,15 @@ CREATE TABLE IF NOT EXISTS user_logins ( } func initPtrTable(conn *sql.DB) error { - if _, err := conn.Exec("DROP TABLE IF EXISTS ptr"); err != nil { - return xerrors.Errorf("failed to drop ptr table: %w", err) - } - sql := ` -CREATE TABLE IF NOT EXISTS ptr ( - id bigint(20) unsigned NOT NULL AUTO_INCREMENT, - intptr int, - int8ptr int, - int16ptr int, - int32ptr int, - int64ptr int, - uintptr int unsigned, - uint8ptr int unsigned, - uint16ptr int unsigned, - uint32ptr int unsigned, - uint64ptr bigint unsigned, - float32ptr float, - float64ptr double, - boolptr tinyint, - bytesptr varchar(255), - stringptr varchar(255), - timeptr datetime, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8 -` - if _, err := conn.Exec(sql); err != nil { - return xerrors.Errorf("failed to create ptr table: %w", err) + if err := initTable(conn, "ptr"); err != nil { + return xerrors.Errorf("failed to exec ptr: %w", err) } - if _, err := conn.Exec("INSERT INTO `ptr` () values ()"); err != nil { + qh := driver.Adapter.QueryHelper() + if _, err := conn.Exec(fmt.Sprintf("INSERT INTO %s VALUES (DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT,DEFAULT)", qh.Quote("ptr"))); err != nil { return xerrors.Errorf("failed to insert empty record to ptr table: %w", err) } - if _, err := conn.Exec(` -INSERT INTO ptr + if _, err := conn.Exec(fmt.Sprintf(` +INSERT INTO %s ( intptr, int8ptr, @@ -184,35 +169,21 @@ INSERT INTO ptr timeptr ) values - (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NOW()) -`, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1.23, 4.56, true, "bytes", "string"); err != nil { + (%s) +`, qh.Quote("ptr"), qh.Placeholders(16)), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1.23, 4.56, true, "bytes", "string", time.Now()); err != nil { return xerrors.Errorf("failed to insert default value to ptr table: %w", err) } return nil } func initUserLogTable(conn *sql.DB) error { - if _, err := conn.Exec("DROP TABLE IF EXISTS user_logs"); err != nil { - return xerrors.Errorf("failed to drop user_logs table: %w", err) + if err := initTable(conn, "user_logs"); err != nil { + return xerrors.Errorf("failed to exec user_logs: %w", err) } - sql := ` -CREATE TABLE IF NOT EXISTS user_logs ( - id bigint(20) unsigned NOT NULL AUTO_INCREMENT, - user_id bigint(20) unsigned NOT NULL, - content_type varchar(255) NOT NULL, - content_id bigint(20) unsigned NOT NULL, - created_at datetime NOT NULL, - updated_at datetime NOT NULL, - PRIMARY KEY (id), - KEY (user_id, created_at), - KEY (user_id, content_type, content_id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8 -` - if _, err := conn.Exec(sql); err != nil { - return xerrors.Errorf("failed to create user_logs table: %w", err) - } - - if _, err := conn.Exec("INSERT INTO `user_logs` (`user_id`,`content_type`,`content_id`,`created_at`,`updated_at`) VALUES (?, ?, ?, ?, ?)", 1, "rapidash", 1, time.Now(), time.Now()); err != nil { + qh := driver.Adapter.QueryHelper() + columns := []string{qh.Quote("user_id"), qh.Quote("content_type"), qh.Quote("content_id"), qh.Quote("created_at"), qh.Quote("updated_at")} + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", qh.Quote("user_logs"), strings.Join(columns, ","), qh.Placeholders(5)) + if _, err := conn.Exec(query, 1, "rapidash", 1, time.Now(), time.Now()); err != nil { return xerrors.Errorf("failed to insert into user_logs table: %w", err) } @@ -220,21 +191,9 @@ CREATE TABLE IF NOT EXISTS user_logs ( } func initEmptyTable(conn *sql.DB) error { - if _, err := conn.Exec("DROP TABLE IF EXISTS empties"); err != nil { - return xerrors.Errorf("failed to drop empties table: %w", err) + if err := initTable(conn, "empties"); err != nil { + return xerrors.Errorf("failed to exec empties: %w", err) } - - sql := ` -CREATE TABLE empties ( - id bigint(20) unsigned NOT NULL, - PRIMARY KEY (id) -) ENGINE=InnoDB DEFAULT CHARSET=utf8; -` - - if _, err := conn.Exec(sql); err != nil { - return xerrors.Errorf("failed to create empties table: %w", err) - } - return nil } @@ -253,6 +212,7 @@ func initCache(conn *sql.DB, typ CacheServerType) error { ServerAddrs(serverAddrs), LogMode(LogModeJSON), LogEnabled(true), + DatabaseAdapter(driver.DBType), ) if err != nil { return xerrors.Errorf("failed to create rapidash instance: %w", err) @@ -299,7 +259,7 @@ func initCache(conn *sql.DB, typ CacheServerType) error { func TestMain(m *testing.M) { var err error - conn, err = sql.Open("mysql", "root:@tcp(localhost:3306)/rapidash?parseTime=true") + conn, err = sql.Open(driver.Name, driver.Source) if err != nil { panic(err) } @@ -308,7 +268,6 @@ func TestMain(m *testing.M) { } result := m.Run() - os.Exit(result) } @@ -389,11 +348,11 @@ func TestTx_CreateByTableContext(t *testing.T) { NoError(t, err) NotEqualf(t, id, 0, "last insert id is zero") var findUserFromSLCByPrimaryKey UserLogin - builder := NewQueryBuilder("user_logins").Eq("id", uint64(0)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(0)) NoError(t, tx.FindByQueryBuilder(builder, &findUserFromSLCByPrimaryKey)) Equal(t, findUserFromSLCByPrimaryKey.ID, userLogin.ID) var findUserFromSLCByUniqueKey UserLogin - builder = NewQueryBuilder("user_logins").Eq("user_id", uint64(0)).Eq("user_session_id", uint64(1000)) + builder = NewQueryBuilder("user_logins", driver.Adapter).Eq("user_id", uint64(0)).Eq("user_session_id", uint64(1000)) NoError(t, tx.FindByQueryBuilder(builder, &findUserFromSLCByUniqueKey)) Equal(t, findUserFromSLCByPrimaryKey.ID, userLogin.ID) NoError(t, tx.Commit()) @@ -417,7 +376,7 @@ func TestTx_FindByQueryBuilderContext(t *testing.T) { defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() NoError(t, tx.Commit()) - builder := NewQueryBuilder("events") + builder := NewQueryBuilder("events", driver.Adapter) var events EventSlice if err := tx.FindByQueryBuilderContext(context.Background(), builder, &events); err != nil { if !xerrors.Is(err, ErrAlreadyCommittedTransaction) { @@ -432,7 +391,7 @@ func TestTx_FindByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("events") + builder := NewQueryBuilder("events", driver.Adapter) var events EventSlice NoError(t, tx.FindByQueryBuilderContext(context.Background(), builder, &events)) NoError(t, tx.Commit()) @@ -442,7 +401,7 @@ func TestTx_FindByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) var userLogin UserLogin NoError(t, tx.FindByQueryBuilderContext(context.Background(), builder, &userLogin)) NoError(t, tx.Commit()) @@ -452,7 +411,7 @@ func TestTx_FindByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins") + builder := NewQueryBuilder("user_logins", driver.Adapter) var userLogins UserLogins if err := tx.FindByQueryBuilderContext(context.Background(), builder, &userLogins); err != nil { if !xerrors.Is(err, ErrConnectionOfTransaction) { @@ -467,7 +426,7 @@ func TestTx_FindByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("users") + builder := NewQueryBuilder("users", driver.Adapter) var userLogins UserLogins if err := tx.FindByQueryBuilderContext(context.Background(), builder, &userLogins); err == nil { t.Fatal("err is nil\n") @@ -477,7 +436,7 @@ func TestTx_FindByQueryBuilderContext(t *testing.T) { tx, err := cache.Begin(conn) NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logs").Eq("id", uint64(1)).Gte("content_id", uint64(1)).Lte("content_id", uint64(1)) + builder := NewQueryBuilder("user_logs", driver.Adapter).Eq("id", uint64(1)).Gte("content_id", uint64(1)).Lte("content_id", uint64(1)) var userLogs UserLogs NoError(t, tx.FindByQueryBuilderContext(context.Background(), builder, &userLogs)) NoError(t, tx.Commit()) @@ -495,7 +454,7 @@ func TestTx_CountByQueryBuilder(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("events") + builder := NewQueryBuilder("events", driver.Adapter) count, err := tx.CountByQueryBuilder(builder) NoError(t, err) NotEqualf(t, count, 0, "failed count") @@ -506,7 +465,7 @@ func TestTx_CountByQueryBuilder(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins") + builder := NewQueryBuilder("user_logins", driver.Adapter) count, err := tx.CountByQueryBuilder(builder) NoError(t, err) NotEqualf(t, count, 0, "failed count") @@ -517,7 +476,7 @@ func TestTx_CountByQueryBuilder(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("unknown") + builder := NewQueryBuilder("unknown", driver.Adapter) if _, err := tx.CountByQueryBuilder(builder); err == nil { t.Fatal("err is nil") } @@ -530,7 +489,7 @@ func TestTx_FindAllByTable(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("events") + builder := NewQueryBuilder("events", driver.Adapter) count, err := tx.CountByQueryBuilder(builder) NoError(t, err) var events EventSlice @@ -538,7 +497,7 @@ func TestTx_FindAllByTable(t *testing.T) { Equalf(t, len(events), int(count), "invalid events length") - builder = NewQueryBuilder("user_logins") + builder = NewQueryBuilder("user_logins", driver.Adapter) count, err = tx.CountByQueryBuilder(builder) NoError(t, err) var userLogins UserLogins @@ -565,14 +524,14 @@ func TestTx_UpdateByQueryBuilder(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - findBuilder := NewQueryBuilder("user_logins"). + findBuilder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("user_session_id", uint64(1)) var userLogin UserLogin NoError(t, tx.FindByQueryBuilder(findBuilder, &userLogin)) NotEqualf(t, userLogin.ID, 0, "cannot find userLogin") - builder := NewQueryBuilder("user_logins").Eq("id", userLogin.ID) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", userLogin.ID) NoError(t, tx.UpdateByQueryBuilder(builder, map[string]interface{}{ "login_param_id": uint64(10), })) @@ -588,7 +547,7 @@ func TestTx_UpdateByQueryBuilderContext(t *testing.T) { defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() NoError(t, tx.Commit()) - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) if err := tx.UpdateByQueryBuilderContext(context.Background(), builder, map[string]interface{}{ "login_param_id": uint64(10), }); err != nil { @@ -606,7 +565,7 @@ func TestTx_UpdateByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("events").Eq("id", uint64(1)) + builder := NewQueryBuilder("events", driver.Adapter).Eq("id", uint64(1)) var event Event NoError(t, tx.FindByQueryBuilder(builder, &event)) NotEqualf(t, event.ID, 0, "cannot find event") @@ -622,7 +581,7 @@ func TestTx_UpdateByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) if err := tx.UpdateByQueryBuilderContext(context.Background(), builder, map[string]interface{}{ "login_param_id": uint64(10), }); err != nil { @@ -640,14 +599,14 @@ func TestTx_UpdateByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - findBuilder := NewQueryBuilder("user_logins"). + findBuilder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("user_session_id", uint64(1)) var userLogin UserLogin NoError(t, tx.FindByQueryBuilder(findBuilder, &userLogin)) NotEqualf(t, userLogin.ID, 0, "cannot find userLogin") - builder := NewQueryBuilder("user_logins").Eq("id", userLogin.ID) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", userLogin.ID) NoError(t, tx.UpdateByQueryBuilderContext(context.Background(), builder, map[string]interface{}{ "login_param_id": uint64(10), })) @@ -660,7 +619,7 @@ func TestTx_UpdateByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("rapidash").Eq("id", uint64(1)) + builder := NewQueryBuilder("rapidash", driver.Adapter).Eq("id", uint64(1)) if err := tx.UpdateByQueryBuilderContext(context.Background(), builder, map[string]interface{}{ "start_week": uint8(10), }); err == nil { @@ -676,14 +635,14 @@ func TestTx_DeleteByQueryBuilder(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - findBuilder := NewQueryBuilder("user_logins"). + findBuilder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("user_session_id", uint64(1)) var userLogin UserLogin NoError(t, tx.FindByQueryBuilder(findBuilder, &userLogin)) NotEqualf(t, userLogin.ID, 0, "cannot find userLogin") - builder := NewQueryBuilder("user_logins").Eq("id", userLogin.ID) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", userLogin.ID) NoError(t, tx.DeleteByQueryBuilder(builder)) NoError(t, tx.Commit()) } @@ -697,7 +656,7 @@ func TestTx_DeleteByQueryBuilderContext(t *testing.T) { defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() NoError(t, tx.Commit()) - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) if err := tx.DeleteByQueryBuilderContext(context.Background(), builder); err != nil { if !xerrors.Is(err, ErrAlreadyCommittedTransaction) { t.Fatalf("unexpected type err: %+v", err) @@ -713,7 +672,7 @@ func TestTx_DeleteByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("events").Eq("id", uint64(1)) + builder := NewQueryBuilder("events", driver.Adapter).Eq("id", uint64(1)) var event Event NoError(t, tx.FindByQueryBuilder(builder, &event)) NotEqualf(t, event.ID, 0, "cannot find event") @@ -727,7 +686,7 @@ func TestTx_DeleteByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("user_logins").Eq("id", uint64(1)) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", uint64(1)) if err := tx.DeleteByQueryBuilderContext(context.Background(), builder); err != nil { if !xerrors.Is(err, ErrConnectionOfTransaction) { t.Fatalf("unexpected type err: %+v", err) @@ -743,14 +702,14 @@ func TestTx_DeleteByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - findBuilder := NewQueryBuilder("user_logins"). + findBuilder := NewQueryBuilder("user_logins", driver.Adapter). Eq("user_id", uint64(1)). Eq("user_session_id", uint64(1)) var userLogin UserLogin NoError(t, tx.FindByQueryBuilder(findBuilder, &userLogin)) NotEqualf(t, userLogin.ID, 0, "cannot find userLogin") - builder := NewQueryBuilder("user_logins").Eq("id", userLogin.ID) + builder := NewQueryBuilder("user_logins", driver.Adapter).Eq("id", userLogin.ID) NoError(t, tx.DeleteByQueryBuilderContext(context.Background(), builder)) NoError(t, tx.Commit()) }) @@ -761,7 +720,7 @@ func TestTx_DeleteByQueryBuilderContext(t *testing.T) { NoError(t, err) defer func() { NoError(t, tx.RollbackUnlessCommitted()) }() - builder := NewQueryBuilder("rapidash").Eq("id", uint64(1)) + builder := NewQueryBuilder("rapidash", driver.Adapter).Eq("id", uint64(1)) if err := tx.DeleteByQueryBuilderContext(context.Background(), builder); err == nil { t.Fatalf("err is nil") }