Skip to content

Commit

Permalink
Merge pull request #25 from go-gorm/dev
Browse files Browse the repository at this point in the history
feat: db name and distinct method
  • Loading branch information
tr1v3r authored Aug 9, 2021
2 parents 5c71561 + 08f84bb commit 59b492d
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 39 deletions.
49 changes: 35 additions & 14 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,23 @@ var (
// // withSELECT add SELECT clause
// withSELECT stmtOpt = func(stmt *gorm.Statement) *gorm.Statement {
// if _, ok := stmt.Clauses["SELECT"]; !ok {
// stmt.AddClause(clause.Select{})
// stmt.AddClause(clause.Select{Distinct: stmt.Distinct})
// }
// return stmt
// }
)

// buildStmt call statement.Build to combine all clauses in one statement
func (d *DO) buildStmt(opts ...stmtOpt) *gorm.Statement {
// build FOR TEST. call statement.Build to combine all clauses in one statement
func (d *DO) build(opts ...stmtOpt) *gorm.Statement {
stmt := d.db.Statement
for _, opt := range opts {
stmt = opt(stmt)
}

if _, ok := stmt.Clauses["SELECT"]; !ok && len(stmt.Selects) > 0 {
stmt.AddClause(clause.Select{Distinct: stmt.Distinct, Expression: clause.Expr{SQL: strings.Join(stmt.Selects, ",")}})
}

findClauses := func() []string {
for _, cs := range [][]string{createClauses, queryClauses, updateClauses, deleteClauses} {
if _, ok := stmt.Clauses[cs[0]]; ok {
Expand All @@ -123,11 +127,6 @@ func (d *DO) buildStmt(opts ...stmtOpt) *gorm.Statement {
return stmt
}

// func (s *DO) subQueryExpr() clause.Expr {
// stmt := s.buildStmt(withFROM, withSELECT)
// return clause.Expr{SQL: "(" + stmt.SQL.String() + ")", Vars: stmt.Vars}
// }

// Debug return a DO with db in debug mode
func (d *DO) Debug() Dao {
return NewDO(d.db.Debug())
Expand All @@ -152,7 +151,7 @@ func (d *DO) Select(columns ...field.Expr) Dao {
if len(columns) == 0 {
return NewDO(d.db.Clauses(clause.Select{}))
}
return NewDO(d.db.Clauses(clause.Select{Expression: clause.CommaExpression{Exprs: toExpression(columns...)}}))
return NewDO(d.db.Select(buildExpr(d.db.Statement, columns...)))
}

func (d *DO) Where(conds ...Condition) Dao {
Expand All @@ -176,7 +175,7 @@ func (d *DO) Order(columns ...field.Expr) Dao {

func (d *DO) Distinct(columns ...field.Expr) Dao {
Emit(methodDistinct)
return NewDO(d.db.Distinct(toInterfaceSlice(toColNames(d.db.Statement, columns...))...))
return NewDO(d.db.Distinct(toInterfaceSlice(toColumnFullName(d.db.Statement, columns...))...))
}

func (d *DO) Omit(columns ...field.Expr) Dao {
Expand Down Expand Up @@ -441,12 +440,28 @@ func condToExpression(conds ...Condition) []clause.Expression {
return exprs
}

func toColumnFullName(stmt *gorm.Statement, columns ...field.Expr) []string {
return buildColumn(stmt, columns, field.WithAll)
}

func toColNames(stmt *gorm.Statement, columns ...field.Expr) []string {
names := make([]string, len(columns))
for i, col := range columns {
names[i] = col.BuildColumn(stmt)
return buildColumn(stmt, columns)
}

func buildColumn(stmt *gorm.Statement, cols []field.Expr, opts ...field.BuildOpt) []string {
results := make([]string, len(cols))
for i, c := range cols {
results[i] = c.BuildColumn(stmt, opts...)
}
return names
return results
}

func buildExpr(stmt *gorm.Statement, exprs ...field.Expr) []string {
results := make([]string, len(exprs))
for i, e := range exprs {
results[i] = e.BuildExpr(stmt)
}
return results
}

func toInterfaceSlice(value interface{}) []interface{} {
Expand All @@ -459,6 +474,12 @@ func toInterfaceSlice(value interface{}) []interface{} {
res[i] = item
}
return res
case []clause.Column:
res := make([]interface{}, len(v))
for i, item := range v {
res[i] = item
}
return res
default:
return nil
}
Expand Down
54 changes: 39 additions & 15 deletions do_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
var db, _ = gorm.Open(tests.DummyDialector{}, nil)

func init() {
db = db.Debug()

callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"},
Expand Down Expand Up @@ -71,7 +73,7 @@ type User struct {
RegisterAt field.Time
}

var u = func() User {
var u = func() *User {
u := User{
ID: field.NewUint("", "id"),
Name: field.NewString("", "name"),
Expand All @@ -83,7 +85,7 @@ var u = func() User {
}
u.UseDB(db.Session(&gorm.Session{DryRun: true}))
u.UseModel(UserRaw{})
return u
return &u
}()

type Student struct {
Expand Down Expand Up @@ -125,7 +127,7 @@ var teacher = func() *Teacher {
}()

func checkBuildExpr(t *testing.T, e Dao, opts []stmtOpt, result string, vars []interface{}) {
stmt := e.(*DO).buildStmt(opts...)
stmt := e.(*DO).build(opts...)

sql := strings.TrimSpace(stmt.SQL.String())
if sql != result {
Expand All @@ -145,14 +147,36 @@ func TestDO_methods(t *testing.T) {
Result string
}{
{
Expr: u.Select(),
ExpectedVars: nil,
Result: "SELECT *",
Expr: u.Select(),
Result: "SELECT *",
},
{
Expr: u.Select(u.ID, u.Name),
ExpectedVars: nil,
Result: "SELECT `id`, `name`",
Expr: u.Select(u.ID, u.Name),
Result: "SELECT `id`,`name`",
},
{
Expr: u.Distinct(u.Name),
Result: "SELECT DISTINCT `name`",
},
{
Expr: teacher.Distinct(teacher.ID, teacher.Name),
Result: "SELECT DISTINCT `teacher`.`id`,`teacher`.`name`",
},
{
Expr: teacher.Select(teacher.ID, teacher.Name).Distinct(),
Result: "SELECT DISTINCT `teacher`.`id`,`teacher`.`name`",
},
{
Expr: teacher.Distinct().Select(teacher.ID, teacher.Name),
Result: "SELECT DISTINCT `teacher`.`id`,`teacher`.`name`",
},
{
Expr: teacher.Select(teacher.Name.As("n")).Distinct(),
Result: "SELECT DISTINCT `teacher`.`name` AS `n`",
},
{
Expr: teacher.Select(teacher.ID.As("i"), teacher.Name.As("n")).Distinct(),
Result: "SELECT DISTINCT `teacher`.`id` AS `i`,`teacher`.`name` AS `n`",
},
{
Expr: u.Where(u.ID.Eq(10)),
Expand Down Expand Up @@ -220,7 +244,7 @@ func TestDO_methods(t *testing.T) {
{
Expr: u.Select(u.ID, u.Name).Where(u.Age.Gt(18), u.Score.Gte(100)),
ExpectedVars: []interface{}{18, 100.0},
Result: "SELECT `id`, `name` WHERE `age` > ? AND `score` >= ?",
Result: "SELECT `id`,`name` WHERE `age` > ? AND `score` >= ?",
},
// ======================== subquery ========================
{
Expand All @@ -236,7 +260,7 @@ func TestDO_methods(t *testing.T) {
{
Expr: u.Select(u.ID, u.Name).Where(Lte(u.Score, u.Select(u.Score.Avg()).Where(u.Age.Gte(18)))),
ExpectedVars: []interface{}{18},
Result: "SELECT `id`, `name` WHERE `score` <= (SELECT AVG(`score`) FROM `users_info` WHERE `age` >= ?)",
Result: "SELECT `id`,`name` WHERE `score` <= (SELECT AVG(`score`) FROM `users_info` WHERE `age` >= ?)",
},
{
Expr: u.Select(u.ID).Where(In(u.Score, u.Select(u.Score).Where(u.Age.Gte(18)))),
Expand All @@ -246,7 +270,7 @@ func TestDO_methods(t *testing.T) {
{
Expr: u.Select(u.ID).Where(In(u.ID, u.Age, u.Select(u.ID, u.Age).Where(u.Score.Eq(100)))),
ExpectedVars: []interface{}{100.0},
Result: "SELECT `id` WHERE (`id`, `age`) IN (SELECT `id`, `age` FROM `users_info` WHERE `score` = ?)",
Result: "SELECT `id` WHERE (`id`, `age`) IN (SELECT `id`,`age` FROM `users_info` WHERE `score` = ?)",
},
{
Expr: u.Select(u.Age.Avg().As("avgage")).Group(u.Name).Having(Gt(u.Age.Avg(), u.Select(u.Age.Avg()).Where(u.Name.Like("name%")))),
Expand All @@ -259,7 +283,7 @@ func TestDO_methods(t *testing.T) {
Expr: Table(u.Select(u.ID, u.Name).Where(u.Age.Gt(18))).Select(),
Opts: []stmtOpt{withFROM},
ExpectedVars: []interface{}{18},
Result: "SELECT * FROM (SELECT `id`, `name` FROM `users_info` WHERE `age` > ?)",
Result: "SELECT * FROM (SELECT `id`,`name` FROM `users_info` WHERE `age` > ?)",
},
{
Expr: Table(u.Select(u.ID).Where(u.Age.Gt(18)), u.Select(u.ID).Where(u.Score.Gte(100))).Select(),
Expand Down Expand Up @@ -287,7 +311,7 @@ func TestDO_methods(t *testing.T) {
},
{
Expr: student.LeftJoin(teacher, student.Instructor.EqCol(teacher.ID)).Where(teacher.ID.Gt(0)).Select(student.Name, teacher.Name),
Result: "SELECT `student`.`name`, `teacher`.`name` FROM `student` LEFT JOIN `teacher` ON `student`.`instructor` = `teacher`.`id` WHERE `teacher`.`id` > ?",
Result: "SELECT `student`.`name`,`teacher`.`name` FROM `student` LEFT JOIN `teacher` ON `student`.`instructor` = `teacher`.`id` WHERE `teacher`.`id` > ?",
ExpectedVars: []interface{}{int64(0)},
},
{
Expand All @@ -297,7 +321,7 @@ func TestDO_methods(t *testing.T) {
},
{
Expr: student.Join(teacher, student.Instructor.EqCol(teacher.ID)).LeftJoin(teacher, student.ID.Gt(100)).Select(student.ID, student.Name, teacher.Name.As("teacher_name")),
Result: "SELECT `student`.`id`, `student`.`name`, `teacher`.`name` AS `teacher_name` FROM `student` INNER JOIN `teacher` ON `student`.`instructor` = `teacher`.`id` LEFT JOIN `teacher` ON `student`.`id` > ?",
Result: "SELECT `student`.`id`,`student`.`name`,`teacher`.`name` AS `teacher_name` FROM `student` INNER JOIN `teacher` ON `student`.`instructor` = `teacher`.`id` LEFT JOIN `teacher` ON `student`.`id` > ?",
ExpectedVars: []interface{}{int64(100)},
},
}
Expand Down
10 changes: 10 additions & 0 deletions field/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Expr interface {
Column() clause.Column
BuildColumn(*gorm.Statement, ...BuildOpt) string
RawExpr() interface{}
BuildExpr(stmt *gorm.Statement) string

// pirvate do nothing, prevent users from implementing interfaces outside the package
private()
Expand Down Expand Up @@ -86,6 +87,15 @@ func (e expr) BuildColumn(stmt *gorm.Statement, opts ...BuildOpt) string {
return stmt.Quote(col)
}

func (e expr) BuildExpr(stmt *gorm.Statement) string {
if e.expression == nil {
return e.BuildColumn(stmt, WithAll)
}
newStmt := &gorm.Statement{DB: stmt.DB, Table: stmt.Table, Schema: stmt.Schema}
e.expression.Build(newStmt)
return newStmt.SQL.String()
}

func (e expr) RawExpr() interface{} {
if e.expression == nil {
return e.Col
Expand Down
13 changes: 12 additions & 1 deletion generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ type Config struct {
ModelPkgName string // generated model code's package name

queryPkgName string // generated query code's package name
dbNameOpts []check.SchemaNameOpt
}

// WithDbNameOpts set get database name function
func (cfg *Config) WithDbNameOpts(opts ...check.SchemaNameOpt) {
if cfg.dbNameOpts == nil {
cfg.dbNameOpts = make([]check.SchemaNameOpt, 0, len(opts))
}
for _, opt := range opts {
cfg.dbNameOpts = append(cfg.dbNameOpts, opt)
}
}

// genInfo info about generated code
Expand Down Expand Up @@ -80,7 +91,7 @@ func (g *Generator) UseDB(db *gorm.DB) {

// GenerateModel catch table info from db, return a BaseStruct
func (g *Generator) GenerateModel(tableName string, modelName string) *check.BaseStruct {
s, err := check.GenBaseStructs(g.db, g.Config.ModelPkgName, tableName, modelName)
s, err := check.GenBaseStructs(g.db, g.Config.ModelPkgName, tableName, modelName, g.dbNameOpts...)
if err != nil {
log.Fatalf("check struct error: %s", err)
}
Expand Down
2 changes: 1 addition & 1 deletion interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type Condition interface {

type subQuery interface {
UnderlyingDB() *gorm.DB
buildStmt(opts ...stmtOpt) *gorm.Statement
build(opts ...stmtOpt) *gorm.Statement
}

// Dao CRUD methods
Expand Down
24 changes: 18 additions & 6 deletions internal/check/gen_structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ var dataType = map[string]string{
"integer": "int32",
}

type SchemaNameOpt func(db *gorm.DB) string

// GenBaseStructs generate db model by table name
func GenBaseStructs(db *gorm.DB, pkg string, tableName, modelName string) (bases *BaseStruct, err error) {
func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt ...SchemaNameOpt) (bases *BaseStruct, err error) {
if isDBUndefined(db) {
return nil, fmt.Errorf("gen config db is undefined")
}
Expand All @@ -68,7 +70,7 @@ func GenBaseStructs(db *gorm.DB, pkg string, tableName, modelName string) (bases
pkg = ModelPkg
}
singular := singularModel(db.Config)
dbName := getSchemaName(db)
dbName := getSchemaName(db, schemaNameOpt...)
columns, err := getTbColumns(db, dbName, tableName)
if err != nil {
return nil, err
Expand Down Expand Up @@ -106,9 +108,15 @@ func getTbColumns(db *gorm.DB, schemaName string, tableName string) (result []*C
}

// get mysql db' name
var dbNameReg = regexp.MustCompile(`/\w+\?`)

func getSchemaName(db *gorm.DB) string {
var dbNameReg = regexp.MustCompile(`/\w+\??`)

func getSchemaName(db *gorm.DB, opts ...SchemaNameOpt, ) string {
for _, opt := range opts {
name := opt(db)
if name != "" {
return name
}
}
if db == nil || db.Dialector == nil {
return ""
}
Expand All @@ -120,7 +128,11 @@ func getSchemaName(db *gorm.DB) string {
if len(dbName) < 3 {
return ""
}
return dbName[1 : len(dbName)-1]
end := len(dbName)
if strings.HasSuffix(dbName, "?") {
end--
}
return dbName[1:end]
}

// convert Table name or column name to camel case
Expand Down
4 changes: 2 additions & 2 deletions internal/template/tmpl.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 59b492d

Please sign in to comment.