Skip to content

Commit

Permalink
Refine type check for is_custom_strong_type (#676)
Browse files Browse the repository at this point in the history
# Description

Refine the type check for is_custom_strong_type static method as a
follow-up to this pull request:
#649

For generic type annotations like list[str], tuple[str, int], dict[str,
typing.Any] would be types.GenericAlias so return False.
For other types that would be passed to issubclass to check if it is the
subclass of CustomStrongTypeConnection.

The try-except block is still kept since it cannot be confidently
removed due to the uncertainty of TypeError that may occur.
TypeError is not expected to happen, but if it does, we will log it for
debugging and return False.

Apart from ut, e2e flow test using below script can pass:
```
  def test(self):
        from promptflow._sdk._pf_client import PFClient

        client = PFClient()
        client.flows.test(flow=r"D:\proj\github\ms\promptflow\examples\flows\standard\gen-docstring")
```

![image](https://github.com/microsoft/promptflow/assets/46446115/dd47eb2a-f935-4793-b1a8-6159ccf07dae)

---------

Co-authored-by: yalu4 <[email protected]>
  • Loading branch information
16oeahr and yalu4 authored Oct 9, 2023
1 parent 0b9d778 commit 793794b
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
27 changes: 22 additions & 5 deletions src/promptflow/promptflow/contracts/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
# ---------------------------------------------------------

import json
import logging
import types
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Type, TypeVar

from promptflow._constants import CONNECTION_NAME_PROPERTY

from .types import FilePath, PromptTemplate, Secret
from .multimedia import Image
from .types import FilePath, PromptTemplate, Secret

logger = logging.getLogger(__name__)
T = TypeVar("T", bound="Enum")


Expand Down Expand Up @@ -185,14 +188,28 @@ def is_connection_value(val: Any) -> bool:
return val in connections.values() or ConnectionType.is_custom_strong_type(val)

@staticmethod
def is_custom_strong_type(val):
"""Check if the given value is a custom strong type connection."""
def is_custom_strong_type(val: Any) -> bool:
"""Check if the given value is a custom strong type connection.
:param val: The value to check
:type val: Any
:return: Whether the given value is a custom strong type
:rtype: bool
"""

from promptflow._sdk.entities import CustomStrongTypeConnection

# TODO: replace the hotfix "try-except" with a more graceful solution."
val = type(val) if not isinstance(val, type) else val
# Check for instances of GenericAlias (for parameterized generic types like list[str])
if isinstance(val, types.GenericAlias):
return False

try:
return issubclass(val, CustomStrongTypeConnection)
except Exception:
except TypeError as e:
# TypeError is not expected to happen, but if it does, we will log it for debugging and return False.
# The try-except block cannot be confidently removed due to the uncertainty of TypeError that may occur.
logger.warning(f"Failed to check if {val} is a custom strong type: {e}")
return False

@staticmethod
Expand Down
54 changes: 54 additions & 0 deletions src/promptflow/tests/executor/unittests/contracts/test_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Callable, NewType, Optional, Tuple, TypeVar, Union

import pytest

from promptflow._sdk.entities import CustomStrongTypeConnection
from promptflow.contracts.tool import ConnectionType


class MyConnection(CustomStrongTypeConnection):
pass


my_connection = MyConnection(name="my_connection", secrets={"key": "value"})


def some_function():
pass


@pytest.mark.unittest
class TestToolContract:
@pytest.mark.parametrize(
"val, expected_res",
[
(my_connection, True),
(MyConnection, True),
(list, False),
(list[str], False),
(list[int], False),
([1, 2, 3], False),
(float, False),
(int, False),
(5, False),
(str, False),
(some_function, False),
(Union[str, int], False),
# ((int | str), False), # Python 3.10
(tuple, False),
(tuple[str, int], False),
(Tuple[int, ...], False),
(dict[str, Any], False),
({"test1": [1, 2, 3], "test2": [4, 5, 6], "test3": [7, 8, 9]}, False),
(Any, False),
(None, False),
(Optional[str], False),
(TypeVar("T"), False),
(TypeVar, False),
(Callable, False),
(Callable[..., Any], False),
(NewType("MyType", int), False),
],
)
def test_is_custom_strong_type(self, val, expected_res):
assert ConnectionType.is_custom_strong_type(val) == expected_res

0 comments on commit 793794b

Please sign in to comment.