Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleconroy committed Mar 28, 2024
1 parent aa5e345 commit 0dfbfb6
Show file tree
Hide file tree
Showing 16 changed files with 214 additions and 133 deletions.
46 changes: 7 additions & 39 deletions internal/driver.go
Original file line number Diff line number Diff line change
@@ -1,46 +1,14 @@
package golang

type SQLDriver string
import "github.com/sqlc-dev/sqlc-gen-go/internal/opts"

const (
SQLPackagePGXV4 string = "pgx/v4"
SQLPackagePGXV5 string = "pgx/v5"
SQLPackageStandard string = "database/sql"
)

const (
SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4"
SQLDriverPGXV5 = "github.com/jackc/pgx/v5"
SQLDriverLibPQ = "github.com/lib/pq"
SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
)

func parseDriver(sqlPackage string) SQLDriver {
func parseDriver(sqlPackage string) opts.SQLDriver {
switch sqlPackage {
case SQLPackagePGXV4:
return SQLDriverPGXV4
case SQLPackagePGXV5:
return SQLDriverPGXV5
default:
return SQLDriverLibPQ
}
}

func (d SQLDriver) IsPGX() bool {
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5
}

func (d SQLDriver) IsGoSQLDriverMySQL() bool {
return d == SQLDriverGoSQLDriverMySQL
}

func (d SQLDriver) Package() string {
switch d {
case SQLDriverPGXV4:
return SQLPackagePGXV4
case SQLDriverPGXV5:
return SQLPackagePGXV5
case opts.SQLPackagePGXV4:
return opts.SQLDriverPGXV4
case opts.SQLPackagePGXV5:
return opts.SQLDriverPGXV5
default:
return SQLPackageStandard
return opts.SQLDriverLibPQ
}
}
2 changes: 1 addition & 1 deletion internal/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"sort"
"strings"

"github.com/sqlc-dev/plugin-sdk-go/plugin"
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
"github.com/sqlc-dev/plugin-sdk-go/plugin"
)

type Field struct {
Expand Down
13 changes: 7 additions & 6 deletions internal/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ import (
"strings"
"text/template"

"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
"github.com/sqlc-dev/plugin-sdk-go/sdk"
"github.com/sqlc-dev/plugin-sdk-go/metadata"
"github.com/sqlc-dev/plugin-sdk-go/plugin"
"github.com/sqlc-dev/plugin-sdk-go/sdk"
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
)

type tmplCtx struct {
Q string
Package string
SQLDriver SQLDriver
SQLDriver opts.SQLDriver
Enums []Enum
Structs []Struct
GoQueries []Query
Expand Down Expand Up @@ -189,15 +189,15 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
OmitSqlcVersion: options.OmitSqlcVersion,
}

if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != SQLDriverGoSQLDriverMySQL {
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != opts.SQLDriverGoSQLDriverMySQL {
return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql")
}

if tctx.UsesCopyFrom && options.SqlDriver == SQLDriverGoSQLDriverMySQL {
if tctx.UsesCopyFrom && options.SqlDriver == opts.SQLDriverGoSQLDriverMySQL {
if err := checkNoTimesForMySQLCopyFrom(queries); err != nil {
return nil, err
}
tctx.SQLDriver = SQLDriverGoSQLDriverMySQL
tctx.SQLDriver = opts.SQLDriverGoSQLDriverMySQL
}

if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() {
Expand All @@ -209,6 +209,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum,
"comment": sdk.DoubleSlashComment,
"escape": sdk.EscapeBacktick,
"imports": i.Imports,
"hasImports": i.HasImports,
"hasPrefix": strings.HasPrefix,

// These methods are Go specific, they do not belong in the codegen package
Expand Down
4 changes: 2 additions & 2 deletions internal/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package golang
import (
"strings"

"github.com/sqlc-dev/plugin-sdk-go/plugin"
"github.com/sqlc-dev/plugin-sdk-go/sdk"
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
"github.com/sqlc-dev/plugin-sdk-go/sdk"
"github.com/sqlc-dev/plugin-sdk-go/plugin"
)

func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) {
Expand Down
23 changes: 14 additions & 9 deletions internal/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"sort"
"strings"

"github.com/sqlc-dev/plugin-sdk-go/metadata"
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
"github.com/sqlc-dev/plugin-sdk-go/metadata"
)

type fileImports struct {
Expand Down Expand Up @@ -75,6 +75,11 @@ func (i *importer) usesType(typ string) bool {
return false
}

func (i *importer) HasImports(filename string) bool {
imports := i.Imports(filename)
return len(imports[0]) != 0 || len(imports[1]) != 0
}

func (i *importer) Imports(filename string) [][]ImportSpec {
dbFileName := "db.go"
if i.Options.OutputDbFileName != "" {
Expand Down Expand Up @@ -121,10 +126,10 @@ func (i *importer) dbImports() fileImports {

sqlpkg := parseDriver(i.Options.SqlPackage)
switch sqlpkg {
case SQLDriverPGXV4:
case opts.SQLDriverPGXV4:
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgconn"})
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v4"})
case SQLDriverPGXV5:
case opts.SQLDriverPGXV5:
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"})
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"})
default:
Expand Down Expand Up @@ -167,9 +172,9 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
for _, q := range queries {
if q.Cmd == metadata.CmdExecResult {
switch sqlpkg {
case SQLDriverPGXV4:
case opts.SQLDriverPGXV4:
pkg[ImportSpec{Path: "github.com/jackc/pgconn"}] = struct{}{}
case SQLDriverPGXV5:
case opts.SQLDriverPGXV5:
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}] = struct{}{}
default:
std["database/sql"] = struct{}{}
Expand All @@ -184,7 +189,7 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
}

if uses("pgtype.") {
if sqlpkg == SQLDriverPGXV5 {
if sqlpkg == opts.SQLDriverPGXV5 {
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgtype"}] = struct{}{}
} else {
pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{}
Expand Down Expand Up @@ -424,7 +429,7 @@ func (i *importer) copyfromImports() fileImports {
})

std["context"] = struct{}{}
if i.Options.SqlDriver == SQLDriverGoSQLDriverMySQL {
if i.Options.SqlDriver == opts.SQLDriverGoSQLDriverMySQL {
std["io"] = struct{}{}
std["fmt"] = struct{}{}
std["sync/atomic"] = struct{}{}
Expand Down Expand Up @@ -476,9 +481,9 @@ func (i *importer) batchImports() fileImports {
std["errors"] = struct{}{}
sqlpkg := parseDriver(i.Options.SqlPackage)
switch sqlpkg {
case SQLDriverPGXV4:
case opts.SQLDriverPGXV4:
pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{}
case SQLDriverPGXV5:
case opts.SQLDriverPGXV5:
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5"}] = struct{}{}
}

Expand Down
29 changes: 23 additions & 6 deletions internal/mysql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package golang
import (
"log"

"github.com/sqlc-dev/plugin-sdk-go/plugin"
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
"github.com/sqlc-dev/plugin-sdk-go/sdk"
"github.com/sqlc-dev/sqlc-gen-go/internal/debug"
"github.com/sqlc-dev/sqlc-gen-go/internal/opts"
"github.com/sqlc-dev/plugin-sdk-go/plugin"
)

func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
Expand All @@ -31,14 +31,31 @@ func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.C
} else {
if notNull {
if unsigned {
return "uint32"
return "uint8"
}
return "int32"
return "int8"
}
// The database/sql package does not have a sql.NullInt8 type, so we
// use the smallest type they have which is NullInt16
return "sql.NullInt16"
}

case "year":
if notNull {
return "int16"
}
return "sql.NullInt16"

case "smallint":
if notNull {
if unsigned {
return "uint16"
}
return "sql.NullInt32"
return "int16"
}
return "sql.NullInt16"

case "int", "integer", "smallint", "mediumint", "year":
case "int", "integer", "mediumint":
if notNull {
if unsigned {
return "uint32"
Expand Down
64 changes: 64 additions & 0 deletions internal/opts/enum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package opts

import "fmt"

type SQLDriver string

const (
SQLPackagePGXV4 string = "pgx/v4"
SQLPackagePGXV5 string = "pgx/v5"
SQLPackageStandard string = "database/sql"
)

var validPackages = map[string]struct{}{
string(SQLPackagePGXV4): {},
string(SQLPackagePGXV5): {},
string(SQLPackageStandard): {},
}

func validatePackage(sqlPackage string) error {
if _, found := validPackages[sqlPackage]; !found {
return fmt.Errorf("unknown SQL package: %s", sqlPackage)
}
return nil
}

const (
SQLDriverPGXV4 SQLDriver = "github.com/jackc/pgx/v4"
SQLDriverPGXV5 = "github.com/jackc/pgx/v5"
SQLDriverLibPQ = "github.com/lib/pq"
SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
)

var validDrivers = map[string]struct{}{
string(SQLDriverPGXV4): {},
string(SQLDriverPGXV5): {},
string(SQLDriverLibPQ): {},
string(SQLDriverGoSQLDriverMySQL): {},
}

func validateDriver(sqlDriver string) error {
if _, found := validDrivers[sqlDriver]; !found {
return fmt.Errorf("unknown SQL driver: %s", sqlDriver)
}
return nil
}

func (d SQLDriver) IsPGX() bool {
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5
}

func (d SQLDriver) IsGoSQLDriverMySQL() bool {
return d == SQLDriverGoSQLDriverMySQL
}

func (d SQLDriver) Package() string {
switch d {
case SQLDriverPGXV4:
return SQLPackagePGXV4
case SQLDriverPGXV5:
return SQLPackagePGXV5
default:
return SQLPackageStandard
}
}
12 changes: 12 additions & 0 deletions internal/opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,18 @@ func parseOpts(req *plugin.GenerateRequest) (*Options, error) {
}
}

if options.SqlPackage != "" {
if err := validatePackage(options.SqlPackage); err != nil {
return nil, fmt.Errorf("invalid options: %s", err)
}
}

if options.SqlDriver != "" {
if err := validateDriver(options.SqlDriver); err != nil {
return nil, fmt.Errorf("invalid options: %s", err)
}
}

if options.QueryParameterLimit == nil {
options.QueryParameterLimit = new(int32)
*options.QueryParameterLimit = 1
Expand Down
Loading

0 comments on commit 0dfbfb6

Please sign in to comment.