Skip to content

Commit

Permalink
style(generate): format
Browse files Browse the repository at this point in the history
  • Loading branch information
tr1v3r committed Dec 10, 2021
1 parent 39845e5 commit fe851bf
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 117 deletions.
93 changes: 93 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package gen

import (
"fmt"
"path/filepath"
"strings"

"gorm.io/gen/internal/check"
"gorm.io/gen/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/utils/tests"
)

type GenerateMode uint

const (
// WithDefaultQuery create default query in generated code
WithDefaultQuery GenerateMode = 1 << iota

// WithoutContext generate code without context constrain
WithoutContext
)

// Config generator's basic configuration
type Config struct {
db *gorm.DB // db connection

OutPath string // query code path
OutFile string // query code file name, default: gen.go
ModelPkgPath string // generated model code's package name
WithUnitTest bool // generate unit test for query code

// generate model global configuration
FieldNullable bool // generate pointer when field is nullable
FieldWithIndexTag bool // generate with gorm index tag
FieldWithTypeTag bool // generate with gorm column type tag

Mode GenerateMode // generate mode

queryPkgName string // generated query code's package name
dbNameOpts []model.SchemaNameOpt
dataTypeMap map[string]func(detailType string) (dataType string)
fieldJSONTagNS func(columnName string) string
fieldNewTagNS func(columnName string) string
}

// WithDbNameOpts set get database name function
func (cfg *Config) WithDbNameOpts(opts ...model.SchemaNameOpt) {
if cfg.dbNameOpts == nil {
cfg.dbNameOpts = opts
} else {
cfg.dbNameOpts = append(cfg.dbNameOpts, opts...)
}
}

func (cfg *Config) WithDataTypeMap(newMap map[string]func(detailType string) (dataType string)) {
cfg.dataTypeMap = newMap
}

func (cfg *Config) WithJSONTagNameStrategy(ns func(columnName string) (tagContent string)) {
cfg.fieldJSONTagNS = ns
}

func (cfg *Config) WithNewTagNameStrategy(ns func(columnName string) (tagContent string)) {
cfg.fieldNewTagNS = ns
}

// Revise format path and db
func (cfg *Config) Revise() (err error) {
if strings.TrimSpace(cfg.ModelPkgPath) == "" {
cfg.ModelPkgPath = check.DefaultModelPkg
}

cfg.OutPath, err = filepath.Abs(cfg.OutPath)
if err != nil {
return fmt.Errorf("outpath is invalid: %w", err)
}
if cfg.OutPath == "" {
cfg.OutPath = "./query/"
}
if cfg.OutFile == "" {
cfg.OutFile = cfg.OutPath + "/gen.go"
}
cfg.queryPkgName = filepath.Base(cfg.OutPath)

if cfg.db == nil {
cfg.db, _ = gorm.Open(tests.DummyDialector{})
}

return nil
}

func (cfg *Config) judgeMode(mode GenerateMode) bool { return cfg.Mode&mode != 0 }
153 changes: 36 additions & 117 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@ import (
"text/template"

"golang.org/x/tools/imports"

"gorm.io/gorm"

"gorm.io/gen/internal/check"
"gorm.io/gen/internal/model"
"gorm.io/gen/internal/parser"
tmpl "gorm.io/gen/internal/template"
"gorm.io/gorm"
"gorm.io/gorm/utils/tests"
)

// TODO implement some unit tests

// T generic type
type T interface{}

Expand All @@ -46,79 +45,6 @@ func NewGenerator(cfg Config) *Generator {
}
}

type GenerateMode uint

const (
// WithDefaultQuery create default query in generated code
WithDefaultQuery GenerateMode = 1 << iota

// WithoutContext generate code without context constrain
WithoutContext
)

// Config generator's basic configuration
type Config struct {
db *gorm.DB // db connection

OutPath string // query code path
OutFile string // query code file name, default: gen.go
ModelPkgPath string // generated model code's package name
WithUnitTest bool // generate unit test for query code

// generate model global configuration
FieldNullable bool // generate pointer when field is nullable
FieldWithIndexTag bool // generate with gorm index tag
FieldWithTypeTag bool // generate with gorm column type tag

Mode GenerateMode // generate mode

queryPkgName string // generated query code's package name
dbNameOpts []model.SchemaNameOpt
dataTypeMap map[string]func(detailType string) (dataType string)
fieldJSONTagNS func(columnName string) string
fieldNewTagNS func(columnName string) string
}

// WithDbNameOpts set get database name function
func (cfg *Config) WithDbNameOpts(opts ...model.SchemaNameOpt) {
if cfg.dbNameOpts == nil {
cfg.dbNameOpts = opts
} else {
cfg.dbNameOpts = append(cfg.dbNameOpts, opts...)
}
}

func (cfg *Config) WithDataTypeMap(newMap map[string]func(detailType string) (dataType string)) {
cfg.dataTypeMap = newMap
}

func (cfg *Config) WithJSONTagNameStrategy(ns func(columnName string) (tagContent string)) {
cfg.fieldJSONTagNS = ns
}

func (cfg *Config) WithNewTagNameStrategy(ns func(columnName string) (tagContent string)) {
cfg.fieldNewTagNS = ns
}

func (cfg *Config) Revise() (err error) {
if cfg.ModelPkgPath == "" {
cfg.ModelPkgPath = check.DefaultModelPkg
}

cfg.OutPath, err = filepath.Abs(cfg.OutPath)
if err != nil {
return fmt.Errorf("outpath is invalid: %w", err)
}

if cfg.db == nil {
cfg.db, _ = gorm.Open(tests.DummyDialector{})
}

return nil
}

func (cfg *Config) judgeMode(mode GenerateMode) bool { return cfg.Mode&mode != 0 }

// genInfo info about generated code
type genInfo struct {
*check.BaseStruct
Expand Down Expand Up @@ -189,9 +115,10 @@ func (g *Generator) GenerateModelAs(tableName string, modelName string, fieldOpt
})
if err != nil {
g.db.Logger.Error(context.Background(), "generate struct from table fail: %s", err)
panic(fmt.Sprintf("generate struct fail: %s", err))
panic("generate struct fail")
}
g.modelData[s.StructName] = s

g.successInfo(fmt.Sprintf("got %d columns from table <%s>", len(s.Members), s.TableName))
return s
}
Expand Down Expand Up @@ -268,30 +195,19 @@ func (g *Generator) apply(fc interface{}, structs []*check.BaseStruct) {

// Execute generate code to output path
func (g *Generator) Execute() {
var err error

g.successInfo("Start generating code.")

if g.OutPath == "" {
g.OutPath = "./query/"
}
if g.OutFile == "" {
g.OutFile = g.OutPath + "/gen.go"
}
if err := os.MkdirAll(g.OutPath, os.ModePerm); err != nil {
g.db.Logger.Error(context.Background(), "create outpath(%s) fail: %s", g.OutPath, err)
panic("create outpath fail")
}
g.queryPkgName = filepath.Base(g.OutPath)

err = g.generateModelFile()
if err != nil {
g.db.Logger.Error(context.Background(), "generate basic struct from table fail: %s", err)
panic("generate basic struct from table fail")
if err := g.generateModelFile(); err != nil {
g.db.Logger.Error(context.Background(), "generate model struct fail: %s", err)
panic("generate model struct fail")
}

err = g.generateQueryFile()
if err != nil {
if err := g.generateQueryFile(); err != nil {
g.db.Logger.Error(context.Background(), "generate query code fail: %s", err)
panic("generate query code fail")
}
Expand All @@ -309,9 +225,13 @@ func (g *Generator) successInfo(logInfos ...string) {

// generateQueryFile generate query code and save to file
func (g *Generator) generateQueryFile() (err error) {
if len(g.Data) == 0 {
return nil
}

// generate query code for all struct
for _, info := range g.Data {
err = g.generateSubQuery(info)
err = g.generateSingleQueryFile(info)
if err != nil {
return err
}
Expand Down Expand Up @@ -344,6 +264,7 @@ func (g *Generator) generateQueryFile() (err error) {
if err != nil {
return err
}

err = g.output(g.OutFile, buf.Bytes())
if err != nil {
return err
Expand Down Expand Up @@ -376,8 +297,8 @@ func (g *Generator) generateQueryFile() (err error) {
return nil
}

// generateSubQuery generate query code and save to file
func (g *Generator) generateSubQuery(data *genInfo) (err error) {
// generateSingleQueryFile generate query code and save to file
func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) {
var buf bytes.Buffer

err = render(tmpl.Header, &buf, map[string]string{
Expand Down Expand Up @@ -440,29 +361,14 @@ func (g *Generator) generateQueryUnitTestFile(data *genInfo) (err error) {
}

// generateModelFile generate model structures and save to file
func (g *Generator) generateModelFile() (err error) {
var outPath string
outPath, err = filepath.Abs(g.OutPath)
func (g *Generator) generateModelFile() error {
modelOutPath, err := g.getModelOutputPath()
if err != nil {
return err
}
path := filepath.Clean(g.ModelPkgPath)
if path == "" {
path = check.DefaultModelPkg
}
if strings.Contains(path, "/") {
outPath, err = filepath.Abs(path)
if err != nil {
return fmt.Errorf("cannot parse model pkg path: %w", err)
}
outPath += "/"
} else {
outPath = fmt.Sprint(filepath.Dir(outPath), "/", path, "/")
}

if err := os.MkdirAll(outPath, os.ModePerm); err != nil {
g.db.Logger.Error(context.Background(), "create model pkg path(%s) fail: %s", outPath, err)
panic("create model pkg path fail")
if err := os.MkdirAll(modelOutPath, os.ModePerm); err != nil {
return fmt.Errorf("create model pkg path(%s) fail: %s", modelOutPath, err)
}

for _, data := range g.modelData {
Expand All @@ -475,7 +381,8 @@ func (g *Generator) generateModelFile() (err error) {
if err != nil {
return err
}
modelFile := fmt.Sprint(outPath, data.TableName, ".gen.go")

modelFile := modelOutPath + data.TableName + ".gen.go"
err = g.output(modelFile, buf.Bytes())
if err != nil {
return err
Expand All @@ -486,6 +393,18 @@ func (g *Generator) generateModelFile() (err error) {
return nil
}

func (g *Generator) getModelOutputPath() (outPath string, err error) {
if strings.Contains(g.ModelPkgPath, "/") {
outPath, err = filepath.Abs(g.ModelPkgPath)
if err != nil {
return "", fmt.Errorf("cannot parse model pkg path: %w", err)
}
} else {
outPath = filepath.Dir(g.OutPath) + "/" + g.ModelPkgPath
}
return outPath + "/", nil
}

// output format and output
func (g *Generator) output(fileName string, content []byte) error {
result, err := imports.Process(fileName, content, nil)
Expand All @@ -498,7 +417,7 @@ func (g *Generator) output(fileName string, content []byte) error {
for i := startLine; i <= endLine; i++ {
fmt.Println(i, line[i])
}
return fmt.Errorf("cannot format struct file: %w", err)
return fmt.Errorf("cannot format file: %w", err)
}
return outputFile(fileName, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, result)
}
Expand Down

0 comments on commit fe851bf

Please sign in to comment.