Skip to content
10 changes: 7 additions & 3 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,8 @@ def jmespath_to_name(s: str):
model_name=model_name,
jmespath=jmespath,
nrows=nrows,
)
),
"params": {"file": File},
}
# disable prefetch if nrows is set
settings = {"prefetch": 0} if nrows else {}
Expand Down Expand Up @@ -1003,8 +1004,9 @@ def _udf_to_obj(
func: Optional[Union[Callable, UDFObjT]],
params: Union[None, str, Sequence[str]],
output: OutputType,
signal_map,
signal_map: dict[str, Callable],
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just improving typings here.

) -> UDFObjT:
is_batch = target_class.is_input_batched
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be able to check UDF function input arguments for batch UDFs.

is_generator = target_class.is_output_batched
name = self.name or ""

Expand All @@ -1015,7 +1017,9 @@ def _udf_to_obj(
if self._sys:
signals_schema = SignalSchema({"sys": Sys}) | signals_schema

params_schema = signals_schema.slice(sign.params, self._setup)
params_schema = signals_schema.slice(
sign.params, self._setup, is_batch=is_batch
)
Comment on lines +1020 to +1022
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lets us know whether the UDF is batching or not when we check the UDF input argument types.


return target_class._create(sign, params_schema)

Expand Down
4 changes: 2 additions & 2 deletions src/datachain/lib/meta_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import BaseModel, ConfigDict, Field, ValidationError # noqa: F401

from datachain.lib.data_model import DataModel # noqa: F401
from datachain.lib.file import File
from datachain.lib.file import TextFile


class UserModel(BaseModel):
Expand Down Expand Up @@ -130,7 +130,7 @@ def read_meta( # noqa: C901
#

def parse_data(
file: File,
file: TextFile,
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because File.open returns binary data, and we need text (str) below. Now File types in UDF will be automatically converted based on UDF input param.

data_model=spec,
format=format,
jmespath=jmespath,
Expand Down
72 changes: 57 additions & 15 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def __init__(
if not callable(func):
raise SetupError(key, "value must be function or callable class")

def _init_setup_values(self):
def _init_setup_values(self) -> None:
if self.setup_values is not None:
return self.setup_values
return

res = {}
for key, func in self.setup_func.items():
Expand Down Expand Up @@ -398,7 +398,7 @@ def deserialize(schema: dict[str, Any]) -> "SignalSchema":
return SignalSchema(signals)

@staticmethod
def get_flatten_hidden_fields(schema):
def get_flatten_hidden_fields(schema: dict):
custom_types = schema.get("_custom_types", {})
if not custom_types:
return []
Expand Down Expand Up @@ -464,19 +464,61 @@ def contains_file(self) -> bool:
return False

def slice(
self, keys: Sequence[str], setup: Optional[dict[str, Callable]] = None
self,
params: dict[str, Union[DataType, Any]],
setup: Optional[dict[str, Callable]] = None,
is_batch: bool = False,
) -> "SignalSchema":
Comment on lines 466 to 471
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main changes of this PR go into this function.

params is now a dictionary with param names as keys and types as values instead of just a list of input params names.

is_batch is needed to be able to check batching UDF input params types.

Let's say we have a signal foo with type int; then a regular UDF might look like this:

def process(foo: int) -> int:
    ...

and batching UDF looks like this:

def process_batch(foo: list[int]) -> int:
    ...

To be able to check types here we need to know if UDF is batching or not.

# Make new schema that combines current schema and setup signals
setup = setup or {}
setup_no_types = dict.fromkeys(setup.keys(), str)
union = SignalSchema(self.values | setup_no_types)
Comment on lines -471 to -472
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need to create a new schema, since we only want to know whether the param named foo came from setup or not. We are just going to check whether param foo is in the setup params.

# Slice combined schema by keys
schema = {}
for k in keys:
try:
schema[k] = union._find_in_tree(k.split("."))
except SignalResolvingError:
pass
"""
Returns new schema that combines current schema and setup signals.
"""
setup_params = setup.keys() if setup else []
schema: dict[str, DataType] = {}

for param, param_type in params.items():
# This is special case for setup params, they are always treated as strings
if param in setup_params:
schema[param] = str
continue
Comment thread
shcheklein marked this conversation as resolved.

schema_type = self._find_in_tree(param.split("."))

if param_type is Any:
schema[param] = schema_type
continue
Comment on lines +486 to +488
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens if a UDF function param has no type annotation, e.g.

def process(foo) -> int:
    ...


schema_origin = get_origin(schema_type)
param_origin = get_origin(param_type)
Comment on lines +490 to +491
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will use it a couple of times below, so just cache the result of get_origin here.


if schema_origin is Union and type(None) in get_args(schema_type):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This happens if the DataSet schema signal is Optional. In this case, just for simplicity, we might want to allow users not to specify the Optional type in the UDF, so these two declarations will be identical:

def process(foo: Optional[int]) -> int:
    ...

and

def process(foo: int) -> int:
    ...

Note: the most common use case is DataChain.from_csv("path_to_csv_file"), all signals will be optional in this case 🤔 We might want to revisit from_csv method later.

schema_type = get_args(schema_type)[0]
if param_origin is Union and type(None) in get_args(param_type):
param_type = get_args(param_type)[0]

if is_batch:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case of batch UDFs, both of these declarations will be treated the same:

def process(foo: list[int]) -> int:
    ...

and

def process(foo: list) -> int:
    ...

if param_type is list:
schema[param] = schema_type
continue

if param_origin is not list:
raise SignalResolvingError(param.split("."), "is not a list")

param_type = get_args(param_type)[0]

if param_type == schema_type or (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we do this only for File? does it make sense to make it general?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. For Files we are sure that no additional fields exist in the different File type models (File, ImageFile, VideoFile, etc.). But in general, other models inherited from each other may have additional fields, and a simple type conversion may fail. We could also check this, but it would complicate things and I am not sure I see the value here.

We can do it later if needed.

isclass(param_type)
and isclass(schema_type)
and issubclass(param_type, File)
and issubclass(schema_type, File)
):
schema[param] = schema_type
continue
Comment on lines +508 to +515
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UDF input param type should be the same as DataSet signal type.

Special case here: we convert a File signal to TextFile, ImageFile or VideoFile based on the UDF input param type signature, and vice versa: a TextFile DataSet signal can be used as a regular File type if needed.

Let's say we do have DataSet with File type:

DataChain.from_storage("s3://bucket/")

if we know these files are image files, we can use them in UDF like this:

def process(file: ImageFile) -> ImageMeta:
    return file.get_info()

This might be useful, since the File class has no get_info method.

JFYI, other possible way to do this is:

def process(file: File) -> ImageMeta:
    return file.as_image_file().get_info()

The same works in the opposite direction. Let's say we have a DataSet with ImageFile files:

DataChain.from_storage("s3://bucket/", type="image")

we can use regular File in UDF:

def process(file: File) -> int:
    ...

The type will be converted automatically in both directions.


raise SignalResolvingError(
param.split("."),
f"types mismatch: {param_type} != {schema_type}",
)

return SignalSchema(schema, setup)

def row_to_features(
Expand Down
3 changes: 3 additions & 0 deletions src/datachain/lib/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def process(self, file) -> list[float]:
```
"""

is_input_batched = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use the same is_output_batched (maybe rename it to `is_batched`)? It seems it always has the same value.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For generators input will not be batched (single row), but output is batched, thus we need to have separate params for these.

is_output_batched = False
prefetch: int = 0

Expand Down Expand Up @@ -395,6 +396,7 @@ def _prepare_rows(udf_inputs) -> "abc.Generator[Sequence[Any], None, None]":
class BatchMapper(UDFBase):
"""Inherit from this class to pass to `DataChain.batch_map()`."""

is_input_batched = True
is_output_batched = True

def run(
Expand Down Expand Up @@ -481,6 +483,7 @@ def _process_row(row):
class Aggregator(UDFBase):
"""Inherit from this class to pass to `DataChain.agg()`."""

is_input_batched = True
is_output_batched = True

def run(
Expand Down
26 changes: 17 additions & 9 deletions src/datachain/lib/udf_signature.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from collections.abc import Generator, Iterator, Sequence
from dataclasses import dataclass
from typing import Callable, Union, get_args, get_origin
from typing import Any, Callable, Union, get_args, get_origin

from datachain.lib.data_model import DataType, DataTypeNames, is_chain_type
from datachain.lib.signal_schema import SignalSchema
Expand All @@ -18,7 +18,7 @@ def __init__(self, chain: str, msg):
@dataclass
class UdfSignature:
func: Union[Callable, UDFBase]
params: Sequence[str]
params: dict[str, Union[DataType, Any]]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UDF signature params is now a dictionary with param names as keys and param types as values, instead of just a list of param names.

output_schema: SignalSchema

DEFAULT_RETURN_TYPE = str
Expand Down Expand Up @@ -58,15 +58,23 @@ def parse(
if not isinstance(udf_func, UDFBase) and not callable(udf_func):
raise UdfSignatureError(chain, f"UDF '{udf_func}' is not callable")

func_params_map_sign, func_outs_sign, is_iterator = (
UdfSignature._func_signature(chain, udf_func)
func_params_map_sign, func_outs_sign, is_iterator = cls._func_signature(
chain, udf_func
)
Comment on lines +61 to 63
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why not to use self here 👀


udf_params: dict[str, Union[DataType, Any]] = {}
if params:
udf_params = [params] if isinstance(params, str) else params
elif not func_params_map_sign:
udf_params = []
else:
udf_params = list(func_params_map_sign.keys())
udf_params = (
{params: Any} if isinstance(params, str) else dict.fromkeys(params, Any)
)
Comment on lines +67 to +69
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Any type here for params with unknown type, since we don't have an Unknown type in Python 👀

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be None? What is the parameter type if the type annotation is missing ... e.g. how does mypy treat it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the case when params is set in map chain method. Here we are not checking UDF signature param types, e.g.:

dc.map(lambda file: file.parse_info(), params="file")

or

def process_signals(signal1, signal2) -> int:
    return signal1 + signal2

dc.map(process_signals, params=["signal1", "signal2"])

Answering your question: if the parameter type is missing, inspect will return it as inspect.Parameter.empty (see below). In this case we return its type as Any (since there is no Unknown type in Python).

elif func_params_map_sign:
udf_params = {
param: (
param_type if param_type is not inspect.Parameter.empty else Any
)
for param, param_type in func_params_map_sign.items()
}

if output:
udf_output_map = UdfSignature._validate_output(
chain, signal_name, func, func_outs_sign, output
Expand Down
2 changes: 1 addition & 1 deletion tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ def __init__(self):
self.value = MyMapper.DEFAULT_VALUE
self._had_teardown = False

def process(self, *args) -> int:
def process(self, key) -> int:
return self.value

def setup(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/lib/test_datachain_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
self.value = MyMapper.DEFAULT_VALUE
self._had_teardown = False

def process(self, *args) -> int:
def process(self, key) -> int:
return self.value

def setup(self):
Expand All @@ -40,7 +40,7 @@ def __init__(self):
self._had_bootstrap = False
self._had_teardown = False

def __call__(self, *args):
def __call__(self, key):
return None

def bootstrap(self):
Expand Down
Loading