-
Notifications
You must be signed in to change notification settings - Fork 140
Check UDF params types in SignalSchema #973
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bd91d6a
769cbe4
37a2be7
7b44b3f
56165c2
f6ae293
d7da6a1
7ef96e4
848d452
9b31aba
d362de0
237e3e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -629,7 +629,8 @@ def jmespath_to_name(s: str): | |
| model_name=model_name, | ||
| jmespath=jmespath, | ||
| nrows=nrows, | ||
| ) | ||
| ), | ||
| "params": {"file": File}, | ||
| } | ||
| # disable prefetch if nrows is set | ||
| settings = {"prefetch": 0} if nrows else {} | ||
|
|
@@ -1003,8 +1004,9 @@ def _udf_to_obj( | |
| func: Optional[Union[Callable, UDFObjT]], | ||
| params: Union[None, str, Sequence[str]], | ||
| output: OutputType, | ||
| signal_map, | ||
| signal_map: dict[str, Callable], | ||
| ) -> UDFObjT: | ||
| is_batch = target_class.is_input_batched | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To be able to check UDF function input arguments for batch UDFs. |
||
| is_generator = target_class.is_output_batched | ||
| name = self.name or "" | ||
|
|
||
|
|
@@ -1015,7 +1017,9 @@ def _udf_to_obj( | |
| if self._sys: | ||
| signals_schema = SignalSchema({"sys": Sys}) | signals_schema | ||
|
|
||
| params_schema = signals_schema.slice(sign.params, self._setup) | ||
| params_schema = signals_schema.slice( | ||
| sign.params, self._setup, is_batch=is_batch | ||
| ) | ||
|
Comment on lines
+1020
to
+1022
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to be able to know if UDF is batching or now when we do checking UDF input argument types. |
||
|
|
||
| return target_class._create(sign, params_schema) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
| from pydantic import BaseModel, ConfigDict, Field, ValidationError # noqa: F401 | ||
|
|
||
| from datachain.lib.data_model import DataModel # noqa: F401 | ||
| from datachain.lib.file import File | ||
| from datachain.lib.file import TextFile | ||
|
|
||
|
|
||
| class UserModel(BaseModel): | ||
|
|
@@ -130,7 +130,7 @@ def read_meta( # noqa: C901 | |
| # | ||
|
|
||
| def parse_data( | ||
| file: File, | ||
| file: TextFile, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is because |
||
| data_model=spec, | ||
| format=format, | ||
| jmespath=jmespath, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -154,9 +154,9 @@ def __init__( | |
| if not callable(func): | ||
| raise SetupError(key, "value must be function or callable class") | ||
|
|
||
| def _init_setup_values(self): | ||
| def _init_setup_values(self) -> None: | ||
| if self.setup_values is not None: | ||
| return self.setup_values | ||
| return | ||
|
|
||
| res = {} | ||
| for key, func in self.setup_func.items(): | ||
|
|
@@ -398,7 +398,7 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema": | |
| return SignalSchema(signals) | ||
|
|
||
| @staticmethod | ||
| def get_flatten_hidden_fields(schema): | ||
| def get_flatten_hidden_fields(schema: dict): | ||
| custom_types = schema.get("_custom_types", {}) | ||
| if not custom_types: | ||
| return [] | ||
|
|
@@ -464,19 +464,61 @@ def contains_file(self) -> bool: | |
| return False | ||
|
|
||
| def slice( | ||
| self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None | ||
| self, | ||
| params: dict[str, Union[DataType, Any]], | ||
| setup: Optional[dict[str, Callable]] = None, | ||
| is_batch: bool = False, | ||
| ) -> "SignalSchema": | ||
|
Comment on lines
466
to
471
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Main changes of this PR goes into this function.
Let's say we do have a signal def process(foo: int) -> int:
...and batching UDF looks like this: def process_batch(foo: list[int]) -> int:
...To be able to check types here we need to know if UDF is batching or not. |
||
| # Make new schema that combines current schema and setup signals | ||
| setup = setup or {} | ||
| setup_no_types = dict.fromkeys(setup.keys(), str) | ||
| union = SignalSchema(self.values | setup_no_types) | ||
|
Comment on lines
-471
to
-472
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to create new schema, since we are only wanted to know if param with name |
||
| # Slice combined schema by keys | ||
| schema = {} | ||
| for k in keys: | ||
| try: | ||
| schema[k] = union._find_in_tree(k.split(".")) | ||
| except SignalResolvingError: | ||
| pass | ||
| """ | ||
| Returns new schema that combines current schema and setup signals. | ||
| """ | ||
| setup_params = setup.keys() if setup else [] | ||
| schema: dict[str, DataType] = {} | ||
|
|
||
| for param, param_type in params.items(): | ||
| # This is special case for setup params, they are always treated as strings | ||
| if param in setup_params: | ||
| schema[param] = str | ||
| continue | ||
|
shcheklein marked this conversation as resolved.
|
||
|
|
||
| schema_type = self._find_in_tree(param.split(".")) | ||
|
|
||
| if param_type is Any: | ||
| schema[param] = schema_type | ||
| continue | ||
|
Comment on lines
+486
to
+488
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is happens if UDF function param have no type, e.g. |
||
|
|
||
| schema_origin = get_origin(schema_type) | ||
| param_origin = get_origin(param_type) | ||
|
Comment on lines
+490
to
+491
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We will use it couple times below, so just cache the result of |
||
|
|
||
| if schema_origin is Union and type(None) in get_args(schema_type): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This happens if DataSet schema signal is def process(foo: Optional[int]) -> int:
...and def process(foo: int) -> int:
...Note: the most common use case is |
||
| schema_type = get_args(schema_type)[0] | ||
| if param_origin is Union and type(None) in get_args(param_type): | ||
| param_type = get_args(param_type)[0] | ||
|
|
||
| if is_batch: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In case of batch UDFs these both declarations will be the same: def process(foo: list[int]) -> int:
...and def process(foo: list) -> int:
... |
||
| if param_type is list: | ||
| schema[param] = schema_type | ||
| continue | ||
|
|
||
| if param_origin is not list: | ||
| raise SignalResolvingError(param.split("."), "is not a list") | ||
|
|
||
| param_type = get_args(param_type)[0] | ||
|
|
||
| if param_type == schema_type or ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we do this only for File? does it make sense to make it general?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question. For Files we are sure no additional fields are exists in different File type models (File, ImageFile, VideoFile, etc). But in general other models inherited from each other may have additional fields and simple type conversion may fails. We can also check this, but it will complicate things and I am not sure I can see the value here. We can do it later if needed. |
||
| isclass(param_type) | ||
| and isclass(schema_type) | ||
| and issubclass(param_type, File) | ||
| and issubclass(schema_type, File) | ||
| ): | ||
| schema[param] = schema_type | ||
| continue | ||
|
Comment on lines
+508
to
+515
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. UDF input param type should be the same as DataSet signal type. Special case here: we are converting Let's say we do have DataSet with DataChain.from_storage("s3://bucket/")if we know these files are image files, we can use them in UDF like this: def process(file: ImageFile) -> ImageMeta:
return file.get_info()This might be useful, since JFYI, other possible way to do this is: def process(file: File) -> ImageMeta:
return file.as_image_file().get_info()Same forks in the opposite way. Let's say, we do have DataSet with DataChain.from_storage("s3://bucket/", type="image")we can use regular def process(file: File) -> int:
...type will be converted automatically in both ways. |
||
|
|
||
| raise SignalResolvingError( | ||
| param.split("."), | ||
| f"types mismatch: {param_type} != {schema_type}", | ||
| ) | ||
|
|
||
| return SignalSchema(schema, setup) | ||
|
|
||
| def row_to_features( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -159,6 +159,7 @@ def process(self, file) -> list[float]: | |
| ``` | ||
| """ | ||
|
|
||
| is_input_batched = False | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not use the same
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For generators input will not be batched (single row), but output is batched, thus we need to have separate params for these. |
||
| is_output_batched = False | ||
| prefetch: int = 0 | ||
|
|
||
|
|
@@ -395,6 +396,7 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]": | |
| class BatchMapper(UDFBase): | ||
| """Inherit from this class to pass to `DataChain.batch_map()`.""" | ||
|
|
||
| is_input_batched = True | ||
| is_output_batched = True | ||
|
|
||
| def run( | ||
|
|
@@ -481,6 +483,7 @@ def _process_row(row): | |
| class Aggregator(UDFBase): | ||
| """Inherit from this class to pass to `DataChain.agg()`.""" | ||
|
|
||
| is_input_batched = True | ||
| is_output_batched = True | ||
|
|
||
| def run( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| import inspect | ||
| from collections.abc import Generator, Iterator, Sequence | ||
| from dataclasses import dataclass | ||
| from typing import Callable, Union, get_args, get_origin | ||
| from typing import Any, Callable, Union, get_args, get_origin | ||
|
|
||
| from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type | ||
| from datachain.lib.signal_schema import SignalSchema | ||
|
|
@@ -18,7 +18,7 @@ def __init__(self, chain: str, msg): | |
| @dataclass | ||
| class UdfSignature: | ||
| func: Union[Callable, UDFBase] | ||
| params: Sequence[str] | ||
| params: dict[str, Union[DataType, Any]] | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. UDF signature params is now a dictionary with param names as keys and param types as values, instead of just a list of param names. |
||
| output_schema: SignalSchema | ||
|
|
||
| DEFAULT_RETURN_TYPE = str | ||
|
|
@@ -58,15 +58,23 @@ def parse( | |
| if not isinstance(udf_func, UDFBase) and not callable(udf_func): | ||
| raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable") | ||
|
|
||
| func_params_map_sign, func_outs_sign, is_iterator = ( | ||
| UdfSignature._func_signature(chain, udf_func) | ||
| func_params_map_sign, func_outs_sign, is_iterator = cls._func_signature( | ||
| chain, udf_func | ||
| ) | ||
|
Comment on lines
+61
to
63
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure why not to use |
||
|
|
||
| udf_params: dict[str, Union[DataType, Any]] = {} | ||
| if params: | ||
| udf_params = [params] if isinstance(params, str) else params | ||
| elif not func_params_map_sign: | ||
| udf_params = [] | ||
| else: | ||
| udf_params = list(func_params_map_sign.keys()) | ||
| udf_params = ( | ||
| {params: Any} if isinstance(params, str) else dict.fromkeys(params, Any) | ||
| ) | ||
|
Comment on lines
+67
to
+69
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can it be None? what is parameter type is type annotation is missing ... e.g. how does mypy treat it?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the case when dc.map(lambda file: file.parse_info(), params="file")or def process_signals(signal1, signal2) -> int:
return signal1 + signal2
dc.map(process_signals, params=["signal1", "signal2"])Answering your question, if parameter type is missing, |
||
| elif func_params_map_sign: | ||
| udf_params = { | ||
| param: ( | ||
| param_type if param_type is not inspect.Parameter.empty else Any | ||
| ) | ||
| for param, param_type in func_params_map_sign.items() | ||
| } | ||
|
|
||
| if output: | ||
| udf_output_map = UdfSignature._validate_output( | ||
| chain, signal_name, func, func_outs_sign, output | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just improving typings here.