Skip to content

Commit 26f9f21

Browse files
committed
feat: add overrides config option
1 parent 53fa0b2 commit 26f9f21

File tree

20 files changed

+288
-18
lines changed

20 files changed

+288
-18
lines changed

README.md

+44
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,47 @@ class Status(str, enum.Enum):
7676
OPEN = "op!en"
7777
CLOSED = "clo@sed"
7878
```
79+
80+
### Override Column Types
81+
82+
Option: `overrides`
83+
84+
You can override the SQL to Python type mapping for specific columns using the `overrides` option. This is useful for columns with JSON data or other custom types.
85+
86+
Example configuration:
87+
88+
```yaml
89+
options:
90+
package: authors
91+
emit_pydantic_models: true
92+
overrides:
93+
- column: "some_table.payload"
94+
py_import: "my_lib.models"
95+
py_type: "Payload"
96+
```
97+
98+
This will:
99+
1. Override the column `payload` in `some_table` to use the type `Payload`
100+
2. Add an import for `my_lib.models` to the models file
101+
102+
Example output:
103+
104+
```python
105+
# Code generated by sqlc. DO NOT EDIT.
106+
# versions:
107+
# sqlc v1.28.0
108+
109+
import datetime
110+
import pydantic
111+
from typing import Any
112+
113+
import my_lib.models
114+
115+
116+
class SomeTable(pydantic.BaseModel):
117+
id: int
118+
created_at: datetime.datetime
119+
payload: my_lib.models.Payload
120+
```
121+
122+
This is similar to the [overrides functionality in the Go version of sqlc](https://docs.sqlc.dev/en/stable/howto/overrides.html#overriding-types).

internal/config.go

+16-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
package python
22

3+
type OverrideColumn struct {
4+
Column string `json:"column"`
5+
PyType string `json:"py_type"`
6+
PyImport string `json:"py_import"`
7+
}
8+
39
type Config struct {
4-
EmitExactTableNames bool `json:"emit_exact_table_names"`
5-
EmitSyncQuerier bool `json:"emit_sync_querier"`
6-
EmitAsyncQuerier bool `json:"emit_async_querier"`
7-
Package string `json:"package"`
8-
Out string `json:"out"`
9-
EmitPydanticModels bool `json:"emit_pydantic_models"`
10-
EmitStrEnum bool `json:"emit_str_enum"`
11-
QueryParameterLimit *int32 `json:"query_parameter_limit"`
12-
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
10+
EmitExactTableNames bool `json:"emit_exact_table_names"`
11+
EmitSyncQuerier bool `json:"emit_sync_querier"`
12+
EmitAsyncQuerier bool `json:"emit_async_querier"`
13+
Package string `json:"package"`
14+
Out string `json:"out"`
15+
EmitPydanticModels bool `json:"emit_pydantic_models"`
16+
EmitStrEnum bool `json:"emit_str_enum"`
17+
QueryParameterLimit *int32 `json:"query_parameter_limit"`
18+
InflectionExcludeTableNames []string `json:"inflection_exclude_table_names"`
19+
Overrides []OverrideColumn `json:"overrides"`
1320
}

internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/emit_str_enum/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.28.0
4+
import datetime
5+
import pydantic
6+
7+
import my_lib.models
8+
9+
10+
class Book(pydantic.BaseModel):
11+
id: int
12+
created_at: datetime.datetime
13+
payload: my_lib.models.Payload
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.28.0
4+
# source: query.sql
5+
from typing import AsyncIterator, Iterator, Optional
6+
7+
import my_lib.models
8+
import sqlalchemy
9+
import sqlalchemy.ext.asyncio
10+
11+
from db import models
12+
13+
14+
CREATE_BOOK = """-- name: create_book \\:one
15+
INSERT INTO books (payload)
16+
VALUES (:p1)
17+
RETURNING id, created_at, payload
18+
"""
19+
20+
21+
GET_BOOK = """-- name: get_book \\:one
22+
SELECT id, created_at, payload FROM books
23+
WHERE id = :p1 LIMIT 1
24+
"""
25+
26+
27+
LIST_BOOKS = """-- name: list_books \\:many
28+
SELECT id, created_at, payload FROM books
29+
ORDER BY id
30+
"""
31+
32+
33+
class Querier:
34+
def __init__(self, conn: sqlalchemy.engine.Connection):
35+
self._conn = conn
36+
37+
def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]:
38+
row = self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload}).first()
39+
if row is None:
40+
return None
41+
return models.Book(
42+
id=row[0],
43+
created_at=row[1],
44+
payload=row[2],
45+
)
46+
47+
def get_book(self, *, id: int) -> Optional[models.Book]:
48+
row = self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id}).first()
49+
if row is None:
50+
return None
51+
return models.Book(
52+
id=row[0],
53+
created_at=row[1],
54+
payload=row[2],
55+
)
56+
57+
def list_books(self) -> Iterator[models.Book]:
58+
result = self._conn.execute(sqlalchemy.text(LIST_BOOKS))
59+
for row in result:
60+
yield models.Book(
61+
id=row[0],
62+
created_at=row[1],
63+
payload=row[2],
64+
)
65+
66+
67+
class AsyncQuerier:
68+
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
69+
self._conn = conn
70+
71+
async def create_book(self, *, payload: my_lib.models.Payload) -> Optional[models.Book]:
72+
row = (await self._conn.execute(sqlalchemy.text(CREATE_BOOK), {"p1": payload})).first()
73+
if row is None:
74+
return None
75+
return models.Book(
76+
id=row[0],
77+
created_at=row[1],
78+
payload=row[2],
79+
)
80+
81+
async def get_book(self, *, id: int) -> Optional[models.Book]:
82+
row = (await self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": id})).first()
83+
if row is None:
84+
return None
85+
return models.Book(
86+
id=row[0],
87+
created_at=row[1],
88+
payload=row[2],
89+
)
90+
91+
async def list_books(self) -> AsyncIterator[models.Book]:
92+
result = await self._conn.stream(sqlalchemy.text(LIST_BOOKS))
93+
async for row in result:
94+
yield models.Book(
95+
id=row[0],
96+
created_at=row[1],
97+
payload=row[2],
98+
)

internal/endtoend/testdata/emit_type_overrides/my_lib/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from datetime import date
2+
3+
from pydantic import BaseModel
4+
5+
class Payload(BaseModel):
6+
name: str
7+
release_date: date
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
-- name: GetBook :one
2+
SELECT * FROM books
3+
WHERE id = $1 LIMIT 1;
4+
5+
-- name: ListBooks :many
6+
SELECT * FROM books
7+
ORDER BY id;
8+
9+
-- name: CreateBook :one
10+
INSERT INTO books (payload)
11+
VALUES (sqlc.arg(payload))
12+
RETURNING *;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
CREATE TABLE books (
2+
id SERIAL PRIMARY KEY,
3+
created_at TIMESTAMPTZ NOT NULL,
4+
payload JSONB NOT NULL
5+
);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
version: "2"
2+
plugins:
3+
- name: py
4+
wasm:
5+
url: file://../../../../bin/sqlc-gen-python.wasm
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
7+
sql:
8+
- schema: schema.sql
9+
queries: query.sql
10+
engine: postgresql
11+
codegen:
12+
- plugin: py
13+
out: db
14+
options:
15+
package: db
16+
emit_pydantic_models: true
17+
emit_sync_querier: true
18+
emit_async_querier: true
19+
overrides:
20+
- column: "books.payload"
21+
py_import: "my_lib.models"
22+
py_type: "Payload"

internal/endtoend/testdata/exec_result/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/exec_rows/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca"
6+
sha256: "31717935ced1923fdaea102da0d345776c173df7fa6668120be8f7900a7fe938"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/gen.go

+34
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,40 @@ func (q Query) ArgDictNode() *pyast.Node {
181181
}
182182

183183
func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType {
184+
// Parse the configuration
185+
var conf Config
186+
if len(req.PluginOptions) > 0 {
187+
if err := json.Unmarshal(req.PluginOptions, &conf); err != nil {
188+
log.Printf("failed to parse plugin options: %s", err)
189+
}
190+
}
191+
192+
// Check for overrides
193+
if len(conf.Overrides) > 0 && col.Table != nil {
194+
tableName := col.Table.Name
195+
if col.Table.Schema != "" && col.Table.Schema != req.Catalog.DefaultSchema {
196+
tableName = col.Table.Schema + "." + tableName
197+
}
198+
199+
// Look for a matching override
200+
for _, override := range conf.Overrides {
201+
overrideKey := tableName + "." + col.Name
202+
if override.Column == overrideKey {
203+
// Found a match, use the override
204+
typeStr := override.PyType
205+
if override.PyImport != "" && !strings.Contains(typeStr, ".") {
206+
typeStr = override.PyImport + "." + override.PyType
207+
}
208+
return pyType{
209+
InnerType: typeStr,
210+
IsArray: col.IsArray,
211+
IsNull: !col.NotNull,
212+
}
213+
}
214+
}
215+
}
216+
217+
// No override found, use the standard type mapping
184218
typ := pyInnerType(req, col)
185219
return pyType{
186220
InnerType: typ,

0 commit comments

Comments
 (0)