diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 6aca91bd0c..26c0a164f0 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -194,6 +194,9 @@ func (i *importer) interfaceImports() fileImports { overrideTypes[o.GoTypeName] = o.GoImportPath } + if uses("pgtype.") { + pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{} + } _, overrideNullTime := overrideTypes["pq.NullTime"] if uses("pq.NullTime") && !overrideNullTime { pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} @@ -263,6 +266,9 @@ func (i *importer) modelImports() fileImports { if i.usesType("uuid.UUID") && !overrideUUID { pkg[ImportSpec{Path: "github.com/google/uuid"}] = struct{}{} } + if i.usesType("pgtype.") { + pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{} + } for _, o := range i.Settings.Overrides { if o.GoBasicType || o.GoTypeName == "" { @@ -369,7 +375,6 @@ func (i *importer) queryImports(filename string) fileImports { if uses("sql.Null") { std["database/sql"] = struct{}{} } - sqlpkg := SQLPackageFromString(i.Settings.Go.SQLPackage) for _, q := range gq { @@ -396,6 +401,9 @@ func (i *importer) queryImports(filename string) fileImports { overrideTypes[o.GoTypeName] = o.GoImportPath } + if uses("pgtype.") { + pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{} + } if sliceScan() && sqlpkg != SQLPackagePGX { pkg[ImportSpec{Path: "github.com/lib/pq"}] = struct{}{} } diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index 1e893626df..3e032a5580 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -105,8 +105,11 @@ func postgresType(r *compiler.Result, col *compiler.Column, settings config.Comb case "uuid": return "uuid.UUID" - case "inet", "cidr": - return "net.IP" + case "inet": + return "pgtype.Inet" + + case "cidr": + return "pgtype.CIDR" case "macaddr", "macaddr8": return "net.HardwareAddr" diff --git a/internal/endtoend/testdata/func_return/posgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/func_return/posgresql/pgx/go/query.sql.go index b8775974a0..44d4b34f13 100644 --- a/internal/endtoend/testdata/func_return/posgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/func_return/posgresql/pgx/go/query.sql.go @@ -5,7 +5,8 @@ package querytest import ( "context" - "net" + + "github.com/jackc/pgtype" ) const generateSeries = `-- name: GenerateSeries :many @@ -15,7 +16,7 @@ LIMIT 1 ` type GenerateSeriesParams struct { - Column1 net.IP + Column1 pgtype.Inet Column2 int32 } diff --git a/internal/endtoend/testdata/func_return/posgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/func_return/posgresql/stdlib/go/query.sql.go index 990f335916..046263434f 100644 --- a/internal/endtoend/testdata/func_return/posgresql/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/func_return/posgresql/stdlib/go/query.sql.go @@ -5,7 +5,8 @@ package querytest import ( "context" - "net" + + "github.com/jackc/pgtype" ) const generateSeries = `-- name: GenerateSeries :many @@ -15,7 +16,7 @@ LIMIT 1 ` type GenerateSeriesParams struct { - Column1 net.IP + Column1 pgtype.Inet Column2 int32 } diff --git a/internal/endtoend/testdata/ipaddr/pgx/go/models.go b/internal/endtoend/testdata/ipaddr/pgx/go/models.go index c9b78920d4..ed7771b7fd 100644 --- a/internal/endtoend/testdata/ipaddr/pgx/go/models.go +++ b/internal/endtoend/testdata/ipaddr/pgx/go/models.go @@ -3,11 +3,12 @@ package querytest import ( - "net" + "github.com/jackc/pgtype" ) type Foo struct { - Bar bool - Inet net.IP - Cidr net.IP + PresentIp pgtype.Inet + NullableIp pgtype.Inet + PresentCidr pgtype.CIDR + NullableCidr pgtype.CIDR } diff --git a/internal/endtoend/testdata/ipaddr/pgx/go/query.sql.go b/internal/endtoend/testdata/ipaddr/pgx/go/query.sql.go index cb2cda6b98..c1cdd88dad 100644 --- a/internal/endtoend/testdata/ipaddr/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/ipaddr/pgx/go/query.sql.go @@ -8,7 +8,7 @@ import ( ) const get = `-- name: Get :many -SELECT bar, "inet", "cidr" FROM foo LIMIT $1 +SELECT present_ip, nullable_ip, present_cidr, nullable_cidr FROM foo LIMIT $1 ` func (q *Queries) Get(ctx context.Context, limit int32) ([]Foo, error) { @@ -20,7 +20,12 @@ func (q *Queries) Get(ctx context.Context, limit int32) ([]Foo, error) { var items []Foo for rows.Next() { var i Foo - if err := rows.Scan(&i.Bar, &i.Inet, &i.Cidr); err != nil { + if err := rows.Scan( + &i.PresentIp, + &i.NullableIp, + &i.PresentCidr, + &i.NullableCidr, + ); err != nil { return nil, err } items = append(items, i) diff --git a/internal/endtoend/testdata/ipaddr/pgx/query.sql b/internal/endtoend/testdata/ipaddr/pgx/query.sql index 11b6b5c820..daec2e4cae 100644 --- a/internal/endtoend/testdata/ipaddr/pgx/query.sql +++ b/internal/endtoend/testdata/ipaddr/pgx/query.sql @@ -1,4 +1,9 @@ -CREATE TABLE foo (bar bool not null, "inet" inet not null, "cidr" cidr not null); +CREATE TABLE foo ( + present_ip inet not null, + nullable_ip inet, + present_cidr cidr not null, + nullable_cidr cidr +); -- name: Get :many -SELECT bar, "inet", "cidr" FROM foo LIMIT $1; +SELECT * FROM foo LIMIT $1; diff --git a/internal/endtoend/testdata/ipaddr/stdlib/go/models.go b/internal/endtoend/testdata/ipaddr/stdlib/go/models.go index c9b78920d4..ed7771b7fd 100644 --- a/internal/endtoend/testdata/ipaddr/stdlib/go/models.go +++ b/internal/endtoend/testdata/ipaddr/stdlib/go/models.go @@ -3,11 +3,12 @@ package querytest import ( - "net" + "github.com/jackc/pgtype" ) type Foo struct { - Bar bool - Inet net.IP - Cidr net.IP + PresentIp pgtype.Inet + NullableIp pgtype.Inet + PresentCidr pgtype.CIDR + NullableCidr pgtype.CIDR } diff --git a/internal/endtoend/testdata/ipaddr/stdlib/go/query.sql.go b/internal/endtoend/testdata/ipaddr/stdlib/go/query.sql.go index 1bdc311a36..4020f44c12 100644 --- a/internal/endtoend/testdata/ipaddr/stdlib/go/query.sql.go +++ b/internal/endtoend/testdata/ipaddr/stdlib/go/query.sql.go @@ -8,7 +8,7 @@ import ( ) const get = `-- name: Get :many -SELECT bar, "inet", "cidr" FROM foo LIMIT $1 +SELECT present_ip, nullable_ip, present_cidr, nullable_cidr FROM foo LIMIT $1 ` func (q *Queries) Get(ctx context.Context, limit int32) ([]Foo, error) { @@ -20,7 +20,12 @@ func (q *Queries) Get(ctx context.Context, limit int32) ([]Foo, error) { var items []Foo for rows.Next() { var i Foo - if err := rows.Scan(&i.Bar, &i.Inet, &i.Cidr); err != nil { + if err := rows.Scan( + &i.PresentIp, + &i.NullableIp, + &i.PresentCidr, + &i.NullableCidr, + ); err != nil { return nil, err } items = append(items, i) diff --git a/internal/endtoend/testdata/ipaddr/stdlib/query.sql b/internal/endtoend/testdata/ipaddr/stdlib/query.sql index 11b6b5c820..daec2e4cae 100644 --- a/internal/endtoend/testdata/ipaddr/stdlib/query.sql +++ b/internal/endtoend/testdata/ipaddr/stdlib/query.sql @@ -1,4 +1,9 @@ -CREATE TABLE foo (bar bool not null, "inet" inet not null, "cidr" cidr not null); +CREATE TABLE foo ( + present_ip inet not null, + nullable_ip inet, + present_cidr cidr not null, + nullable_cidr cidr +); -- name: Get :many -SELECT bar, "inet", "cidr" FROM foo LIMIT $1; +SELECT * FROM foo LIMIT $1;