Skip to content

Commit 592c1d6

Browse files
committed
feat: add overrides config option
1 parent 53fa0b2 commit 592c1d6

File tree

20 files changed

+279
-18
lines changed

20 files changed

+279
-18
lines changed

README.md

+44
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,47 @@ class Status(str, enum.Enum):
7676
OPEN = "op!en"
7777
CLOSED = "clo@sed"
7878
```
79+
80+
### Override Column Types
81+
82+
Option: `overrides`
83+
84+
You can override the SQL to Python type mapping for specific columns using the `overrides` option. This is useful for columns with JSON data or other custom types.
85+
86+
Example configuration:
87+
88+
```yaml
89+
options:
90+
package: authors
91+
emit_pydantic_models: true
92+
overrides:
93+
- column: "some_table.payload"
94+
py_import: "my_lib.models"
95+
py_type: "Payload"
96+
```
97+
98+
This will:
99+
1. Override the column `payload` in `some_table` to use the type `Payload`
100+
2. Add an import for `my_lib.models` to the models file
101+
102+
Example output:
103+
104+
```python
105+
# Code generated by sqlc. DO NOT EDIT.
106+
# versions:
107+
# sqlc v1.28.0
108+
109+
import datetime
110+
import pydantic
111+
from typing import Any
112+
113+
import my_lib.models
114+
115+
116+
class SomeTable(pydantic.BaseModel):
117+
id: int
118+
created_at: datetime.datetime
119+
payload: my_lib.models.Payload
120+
```
121+
122+
This is similar to the [overrides functionality in the Go version of sqlc](https://docs.sqlc.dev/en/stable/howto/overrides.html#overriding-types).

internal/config.go

+16-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
package python
22

3+
type OverrideColumn struct {
4+
Column string `json:"column"`
5+
PyType string `json:"py_type"`
6+
PyImport string `json:"py_import"`
7+
}
8+
39
type Config struct {
4-
EmitExactTableNames bool `json:"emit_exact_table_names"`
5-
EmitSyncQuerier bool `json:"emit_sync_querier"`
6-
EmitAsyncQuerier bool `json:"emit_async_querier"`
7-
Package string `json:"package"`
8-
Out string `json:"out"`
9-
EmitPydanticModels bool `json:"emit_pydantic_models"`
10-
EmitStrEnum bool `json:"emit_str_enum"`
11-
QueryParameterLimit *int32 `json:"query_parameter_limit"`
12-
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
10+
EmitExactTableNames bool `json:"emit_exact_table_names"`
11+
EmitSyncQuerier bool `json:"emit_sync_querier"`
12+
EmitAsyncQuerier bool `json:"emit_async_querier"`
13+
Package string `json:"package"`
14+
Out string `json:"out"`
15+
EmitPydanticModels bool `json:"emit_pydantic_models"`
16+
EmitStrEnum bool `json:"emit_str_enum"`
17+
QueryParameterLimit *int32 `json:"query_parameter_limit"`
18+
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
19+
Overrides []OverrideColumn `json:"overrides"`
1320
}

internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/emit_str_enum/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.28.0
4+
import pydantic
5+
6+
import my_lib.models
7+
8+
9+
class Book(pydantic.BaseModel):
10+
id: int
11+
payload: my_lib.models.Payload
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.28.0
4+
# source: query.sql
5+
from typing import AsyncIterator, Iterator, Optional
6+
7+
import my_lib.models
8+
import sqlalchemy
9+
import sqlalchemy.ext.asyncio
10+
11+
from db import models
12+
13+
14+
CREATE_BOOK = """-- name: create_book \\:one
15+
INSERT INTO books (payload)
16+
VALUES (:p1)
17+
RETURNING id, payload
18+
"""
19+
20+
21+
GET_BOOK = """-- name: get_book \\:one
22+
SELECT id, payload FROM books
23+
WHERE id = :p1 LIMIT 1
24+
"""
25+
26+
27+
LIST_BOOKS = """-- name: list_books \\:many
28+
SELECT id, payload FROM books
29+
ORDER BY id
30+
"""
31+
32+
33+
class Querier:
34+
def __init__(self, conn: sqlalchemy.engine.Connection):
35+
self._conn = conn
36+
37+
def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]:
38+
row = self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload}).first()
39+
if row is None:
40+
return None
41+
return models.Book(
42+
id=row[0],
43+
payload=row[1],
44+
)
45+
46+
def get_book(self, *, id: int) -> Optional[models.Book]:
47+
row = self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id}).first()
48+
if row is None:
49+
return None
50+
return models.Book(
51+
id=row[0],
52+
payload=row[1],
53+
)
54+
55+
def list_books(self) -> Iterator[models.Book]:
56+
result = self._conn.execute(sqlalchemy.text(LIST_BOOKS))
57+
for row in result:
58+
yield models.Book(
59+
id=row[0],
60+
payload=row[1],
61+
)
62+
63+
64+
class AsyncQuerier:
65+
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
66+
self._conn = conn
67+
68+
async def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]:
69+
row = (await self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload})).first()
70+
if row is None:
71+
return None
72+
return models.Book(
73+
id=row[0],
74+
payload=row[1],
75+
)
76+
77+
async def get_book(self, *, id: int) -> Optional[models.Book]:
78+
row = (await self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id})).first()
79+
if row is None:
80+
return None
81+
return models.Book(
82+
id=row[0],
83+
payload=row[1],
84+
)
85+
86+
async def list_books(self) -> AsyncIterator[models.Book]:
87+
result = await self._conn.stream(sqlalchemy.text(LIST_BOOKS))
88+
async for row in result:
89+
yield models.Book(
90+
id=row[0],
91+
payload=row[1],
92+
)

internal/endtoend/testdata/emit_type_overrides/my_lib/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from datetime import date
2+
3+
from pydantic import BaseModel
4+
5+
class Payload(BaseModel):
6+
name: str
7+
release_date: date
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
-- name: GetBook :one
2+
SELECT * FROM books
3+
WHERE id = $1 LIMIT 1;
4+
5+
-- name: ListBooks :many
6+
SELECT * FROM books
7+
ORDER BY id;
8+
9+
-- name: CreateBook :one
10+
INSERT INTO books (payload)
11+
VALUES (sqlc.arg(payload))
12+
RETURNING *;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
CREATE TABLE books (
2+
id SERIAL PRIMARY KEY,
3+
payload JSONB NOT NULL
4+
);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
version: "2"
2+
plugins:
3+
- name: py
4+
wasm:
5+
url: file://../../../../bin/sqlc-gen-python.wasm
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
7+
sql:
8+
- schema: schema.sql
9+
queries: query.sql
10+
engine: postgresql
11+
codegen:
12+
- plugin: py
13+
out: db
14+
options:
15+
package: db
16+
emit_pydantic_models: true
17+
emit_sync_querier: true
18+
emit_async_querier: true
19+
overrides:
20+
- column: "books.payload"
21+
py_import: "my_lib.models"
22+
py_type: "Payload"

internal/endtoend/testdata/exec_result/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/exec_rows/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/gen.go

+34
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,40 @@ func (q Query) ArgDictNode() *pyast.Node {
181181
}
182182

183183
func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType {
184+
// Parse the configuration
185+
var conf Config
186+
if len(req.PluginOptions) > 0 {
187+
if err := json.Unmarshal(req.PluginOptions, &conf); err != nil {
188+
log.Printf("failed to parse plugin options: %s", err)
189+
}
190+
}
191+
192+
// Check for overrides
193+
if len(conf.Overrides) > 0 && col.Table != nil {
194+
tableName := col.Table.Name
195+
if col.Table.Schema != "" && col.Table.Schema != req.Catalog.DefaultSchema {
196+
tableName = col.Table.Schema + "." + tableName
197+
}
198+
199+
// Look for a matching override
200+
for _, override := range conf.Overrides {
201+
overrideKey := tableName + "." + col.Name
202+
if override.Column == overrideKey {
203+
// Found a match, use the override
204+
typeStr := override.PyType
205+
if override.PyImport != "" && !strings.Contains(typeStr, ".") {
206+
typeStr = override.PyImport + "." + override.PyType
207+
}
208+
return pyType{
209+
InnerType: typeStr,
210+
IsArray: col.IsArray,
211+
IsNull: !col.NotNull,
212+
}
213+
}
214+
}
215+
}
216+
217+
// No override found, use the standard type mapping
184218
typ := pyInnerType(req, col)
185219
return pyType{
186220
InnerType: typ,

internal/imports.go

+28
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,20 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS
9797

9898
pkg := make(map[string]importSpec)
9999

100+
// Add custom imports from overrides
101+
for _, override := range i.C.Overrides {
102+
if override.PyImport != "" {
103+
// Check if it's a standard module or a package import
104+
if strings.Contains(override.PyImport, ".") {
105+
// It's a package import
106+
pkg[override.PyImport] = importSpec{Module: override.PyImport}
107+
} else {
108+
// It's a standard import
109+
std[override.PyImport] = importSpec{Module: override.PyImport}
110+
}
111+
}
112+
}
113+
100114
return std, pkg
101115
}
102116

@@ -167,6 +181,20 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map
167181
}
168182
}
169183

184+
// Add custom imports from overrides for query files
185+
for _, override := range i.C.Overrides {
186+
if override.PyImport != "" {
187+
// Check if it's a standard module or a package import
188+
if strings.Contains(override.PyImport, ".") {
189+
// It's a package import
190+
pkg[override.PyImport] = importSpec{Module: override.PyImport}
191+
} else {
192+
// It's a standard import
193+
std[override.PyImport] = importSpec{Module: override.PyImport}
194+
}
195+
}
196+
}
197+
170198
return std, pkg
171199
}
172200

0 commit comments

Comments
 (0)