Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import re
import subprocess
from collections import defaultdict
from pathlib import Path
from textwrap import dedent
from typing import (
Expand All @@ -24,6 +25,7 @@
FileContents,
HandshakeType,
ListTypeExpr,
LiteralType,
LiteralTypeExpr,
ModuleName,
NoneTypeExpr,
Expand All @@ -33,6 +35,7 @@
TypeName,
UnionTypeExpr,
extract_inner_type,
normalize_special_chars,
render_literal_type,
render_type_expr,
)
Expand Down Expand Up @@ -396,9 +399,12 @@ def {_field_name}(
case NoneTypeExpr():
typeddict_encoder.append("None")
case other:
_o2: DictTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = (
other
)
_o2: (
DictTypeExpr
| OpenUnionTypeExpr
| UnionTypeExpr
| LiteralType
) = other
raise ValueError(f"What does it mean to have {_o2} here?")
if permit_unknown_members:
union = _make_open_union_type_expr(any_of)
Expand Down Expand Up @@ -491,7 +497,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
return (NoneTypeExpr(), [], [], set())
elif type.type == "Date":
typeddict_encoder.append("TODO: dstewart")
return (TypeName("datetime.datetime"), [], [], set())
return (LiteralType("datetime.datetime"), [], [], set())
elif type.type == "array" and type.items:
type_name, module_info, type_chunks, encoder_names = encode_type(
type.items,
Expand Down Expand Up @@ -524,6 +530,9 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
# lambda x: ... vs lambda _: {}
needs_binding = False
encoder_names = set()
# Track effective field names to detect collisions after normalization
# Maps effective name -> list of original field names
effective_field_names: defaultdict[str, list[str]] = defaultdict(list)
if type.properties:
needs_binding = True
typeddict_encoder.append("{")
Expand Down Expand Up @@ -653,19 +662,37 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
value = ""
if base_model != "TypedDict":
value = f"= {field_value}"
# Track $kind -> "kind" mapping for collision detection
effective_field_names["kind"].append(name)

current_chunks.append(
f" kind: Annotated[{render_type_expr(type_name)}, Field(alias={
repr(name)
})]{value}"
)
else:
specialized_name = normalize_special_chars(name)
effective_name = name
extras = []
if name != specialized_name:
if base_model != "BaseModel":
# TODO: alias support for TypedDict
raise ValueError(
f"Field {name} is not a valid Python identifier, but it is in the schema" # noqa: E501
)
# Pydantic doesn't allow leading underscores in field names
effective_name = specialized_name.lstrip("_")
extras.append(f"alias={repr(name)}")

effective_field_names[effective_name].append(name)

if name not in type.required:
if base_model == "TypedDict":
current_chunks.append(
reindent(
" ",
f"""\
{name}: NotRequired[{
{effective_name}: NotRequired[{
render_type_expr(
UnionTypeExpr([type_name, NoneTypeExpr()])
)
Expand All @@ -674,11 +701,13 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
)
)
else:
extras.append("default=None")

current_chunks.append(
reindent(
" ",
f"""\
{name}: {
{effective_name}: {
render_type_expr(
UnionTypeExpr(
[
Expand All @@ -687,15 +716,30 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
]
)
)
} = None
} = Field({", ".join(extras)})
""",
)
)
else:
extras_str = ""
if len(extras) != 0:
extras_str = f" = Field({', '.join(extras)})"

current_chunks.append(
f" {name}: {render_type_expr(type_name)}"
f" {effective_name}: {render_type_expr(type_name)}{extras_str}" # noqa: E501
)
typeddict_encoder.append(",")

# Check for field name collisions after processing all fields
for effective_name, original_names in effective_field_names.items():
if len(original_names) > 1:
error_msg = (
f"Field name collision: fields {original_names} all normalize "
f"to the same effective name '{effective_name}'"
)

raise ValueError(error_msg)

typeddict_encoder.append("}")
# exclude_none
typeddict_encoder = (
Expand Down
31 changes: 30 additions & 1 deletion src/replit_river/codegen/typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import NewType, assert_never, cast

SPECIAL_CHARS = [".", "-", ":", "/", "@", " ", "$", "!", "?", "=", "&", "|", "~", "`"]

ModuleName = NewType("ModuleName", str)
ClassName = NewType("ClassName", str)
FileContents = NewType("FileContents", str)
Expand All @@ -23,6 +25,20 @@ def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class LiteralType:
value: str

def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")

def __eq__(self, other: object) -> bool:
return isinstance(other, LiteralType) and other.value == self.value

def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)


@dataclass(frozen=True)
class NoneTypeExpr:
def __str__(self) -> str:
Expand Down Expand Up @@ -111,6 +127,7 @@ def __lt__(self, other: object) -> bool:

TypeExpression = (
TypeName
| LiteralType
| NoneTypeExpr
| DictTypeExpr
| ListTypeExpr
Expand Down Expand Up @@ -145,6 +162,12 @@ def work(
raise ValueError("Incoherent state when trying to flatten unions")


def normalize_special_chars(value: str) -> str:
for char in SPECIAL_CHARS:
value = value.replace(char, "_")
return value


def render_type_expr(value: TypeExpression) -> str:
match _flatten_nested_unions(value):
case DictTypeExpr(nested):
Expand Down Expand Up @@ -192,7 +215,9 @@ def render_type_expr(value: TypeExpression) -> str:
"]"
)
case TypeName(name):
return name
return normalize_special_chars(name)
case LiteralType(literal_value):
return literal_value
case NoneTypeExpr():
return "None"
case other:
Expand Down Expand Up @@ -223,6 +248,10 @@ def extract_inner_type(value: TypeExpression) -> TypeName:
)
case TypeName(name):
return TypeName(name)
case LiteralType(name):
raise ValueError(
f"Attempting to extract from a literal type: {repr(value)}"
)
case NoneTypeExpr():
raise ValueError(
f"Attempting to extract from a literal 'None': {repr(value)}",
Expand Down
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime, timezone
from typing import Any, Literal, Mapping

import nanoid
Expand Down Expand Up @@ -55,7 +56,11 @@ def deserialize_request(request: dict) -> str:


def serialize_response(response: str) -> dict:
return {"data": response}
return {
"data": response,
"data2": datetime.now(timezone.utc),
"data-3": {"data-test": "test"},
}


def deserialize_response(response: dict) -> str:
Expand Down
8 changes: 8 additions & 0 deletions tests/v1/codegen/rpc/generated/test_service/rpc_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def encode_Rpc_MethodInput(
for (k, v) in (
{
"data": x.get("data"),
"data2": x.get("data2"),
}
).items()
if v is not None
Expand All @@ -38,10 +39,17 @@ def encode_Rpc_MethodInput(

class Rpc_MethodInput(TypedDict):
data: str
data2: datetime.datetime


class Rpc_MethodOutputData_3(BaseModel):
data_test: str | None = Field(alias="data-test", default=None)


class Rpc_MethodOutput(BaseModel):
data: str
data_3: Rpc_MethodOutputData_3 = Field(alias="data-3")
data2: datetime.datetime


Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter(
Expand Down
30 changes: 30 additions & 0 deletions tests/v1/codegen/rpc/invalid-schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"services": {
"test_service": {
"procedures": {
"rpc_method": {
"input": {
"type": "boolean"
},
"output": {
"type": "object",
"properties": {
"data:3": {
"type": "Date"
},
"data-3": {
"type": "boolean"
}
},
"required": ["data:3"]
},
"errors": {
"not": {}
},
"type": "rpc"
}
}
}
}
}

19 changes: 17 additions & 2 deletions tests/v1/codegen/rpc/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,33 @@
"properties": {
"data": {
"type": "string"
},
"data2": {
"type": "Date"
}
},
"required": ["data"]
"required": ["data", "data2"]
},
"output": {
"type": "object",
"properties": {
"data": {
"type": "string"
},
"data2": {
"type": "Date"
},
"data-3": {
"type": "object",
"properties": {
"data-test": {
"type": "string"
}
},
"required": []
}
},
"required": ["data"]
"required": ["data", "data2", "data-3"]
},
"errors": {
"not": {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class NeedsenumobjectOutputFooOneOf_out_second(BaseModel):


class NeedsenumobjectOutput(BaseModel):
foo: NeedsenumobjectOutputFoo | None = None
foo: NeedsenumobjectOutputFoo | None = Field(default=None)


NeedsenumobjectOutputTypeAdapter: TypeAdapter[NeedsenumobjectOutput] = TypeAdapter(
Expand All @@ -105,11 +105,11 @@ class NeedsenumobjectOutput(BaseModel):


class NeedsenumobjectErrorsFooAnyOf_0(BaseModel):
beep: Literal["err_first"] | None = None
beep: Literal["err_first"] | None = Field(default=None)


class NeedsenumobjectErrorsFooAnyOf_1(BaseModel):
borp: Literal["err_second"] | None = None
borp: Literal["err_second"] | None = Field(default=None)


NeedsenumobjectErrorsFoo = Annotated[
Expand All @@ -121,7 +121,7 @@ class NeedsenumobjectErrorsFooAnyOf_1(BaseModel):


class NeedsenumobjectErrors(RiverError):
foo: NeedsenumobjectErrorsFoo | None = None
foo: NeedsenumobjectErrorsFoo | None = Field(default=None)


NeedsenumobjectErrorsTypeAdapter: TypeAdapter[NeedsenumobjectErrors] = TypeAdapter(
Expand Down
27 changes: 27 additions & 0 deletions tests/v1/codegen/test_invalid_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from io import StringIO

import pytest

from replit_river.codegen.client import schema_to_river_client_codegen


def test_field_name_collision_error() -> None:
"""Test that codegen raises ValueError for field name collisions."""

with pytest.raises(ValueError) as exc_info:
schema_to_river_client_codegen(
read_schema=lambda: open("tests/v1/codegen/rpc/invalid-schema.json"),
target_path="tests/v1/codegen/rpc/generated",
client_name="InvalidClient",
typed_dict_inputs=True,
file_opener=lambda _: StringIO(),
method_filter=None,
protocol_version="v1.1",
)

# Check that the error message matches the expected format for field name collision
error_message = str(exc_info.value)
assert "Field name collision" in error_message
assert "data:3" in error_message
assert "data-3" in error_message
assert "all normalize to the same effective name 'data_3'" in error_message
Loading
Loading