Skip to content

Commit 3ef8157

Browse files
authored
change aggregate function type to a dataclass (#2507)
1 parent 9756678 commit 3ef8157

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

packages/py-moose-lib/moose_lib/data_models.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
type Key[T: (str, int)] = T
1616
type JWT[T] = T
1717

18-
type Aggregated[T, agg_func] = Annotated[T, agg_func]
19-
2018

2119
@dataclasses.dataclass # a BaseModel in the annotations will confuse pydantic
2220
class ClickhousePrecision:
@@ -41,8 +39,16 @@ def clickhouse_datetime64(precision: int) -> Type[datetime]:
4139
return Annotated[datetime, ClickhousePrecision(precision=precision)]
4240

4341

44-
class AggregateFunction(BaseModel):
45-
model_config = ConfigDict(arbitrary_types_allowed=True)
42+
def aggregated[T](
43+
result_type: Type[T],
44+
agg_func: str,
45+
param_types: list[type | GenericAlias | _BaseGenericAlias]
46+
) -> Type[T]:
47+
return Annotated[result_type, AggregateFunction(agg_func=agg_func, param_types=param_types)]
48+
49+
50+
@dataclasses.dataclass
51+
class AggregateFunction:
4652
agg_func: str
4753
param_types: list[type | GenericAlias | _BaseGenericAlias]
4854

@@ -120,12 +126,6 @@ def handle_annotation(t: type, md: list[Any]) -> Tuple[type, list[Any]]:
120126
return handle_annotation(t.__value__, md)
121127
if get_origin(t) is Annotated:
122128
return handle_annotation(t.__origin__, md + list(t.__metadata__)) # type: ignore
123-
if get_origin(t) is Aggregated:
124-
args = get_args(t)
125-
agg_func = args[1]
126-
if not isinstance(agg_func, AggregateFunction):
127-
raise ValueError("Pass an AggregateFunction to Aggregated")
128-
return handle_annotation(args[0], md + [agg_func])
129129
return t, md
130130

131131

@@ -285,8 +285,8 @@ def validate(value: Any, _: Any) -> Any:
285285
def is_array_nested_type(data_type: DataType) -> bool:
286286
"""Type guard to check if a data type is Array(Nested(...))."""
287287
return (
288-
isinstance(data_type, ArrayType) and
289-
isinstance(data_type.element_type, Nested)
288+
isinstance(data_type, ArrayType) and
289+
isinstance(data_type.element_type, Nested)
290290
)
291291

292292

0 commit comments

Comments
 (0)