Skip to content

Commit

Permalink
Reland union type (#5900)
Browse files Browse the repository at this point in the history
* Reapply "Add union link connection type support (#5806)" (#5889)

This reverts commit bf9a90a.

* Fix union type breaks existing type workarounds

* Add non-string test

* Add tests for hacks and non-string types

* Support python versions lower than 3.11
  • Loading branch information
webfiltered authored Dec 4, 2024
1 parent 4827244 commit 4e402b1
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 3 deletions.
39 changes: 39 additions & 0 deletions comfy_execution/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations


def validate_node_input(
received_type: str, input_type: str, strict: bool = False
) -> bool:
"""
received_type and input_type are both strings of the form "T1,T2,...".
If strict is True, the input_type must contain the received_type.
For example, if received_type is "STRING" and input_type is "STRING,INT",
this will return True. But if received_type is "STRING,INT" and input_type is
"INT", this will return False.
If strict is False, the input_type must have overlap with the received_type.
For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
this will return True.
Supports pre-union type extension behaviour of ``__ne__`` overrides.
"""
# If the types are exactly the same, we can return immediately
# Use pre-union behaviour: inverse of `__ne__`
if not received_type != input_type:
return True

# Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str):
return False

# Split the type strings into sets for comparison
received_types = set(t.strip() for t in received_type.split(","))
input_types = set(t.strip() for t in input_type.split(","))

if strict:
# In strict mode, all received types must be in the input types
return received_types.issubset(input_types)
else:
# In non-strict mode, there must be at least one type in common
return len(received_types.intersection(input_types)) > 0
6 changes: 3 additions & 3 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
from comfy.cli_args import args

class ExecutionResult(Enum):
Expand Down Expand Up @@ -527,7 +528,6 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
comfy.model_management.unload_all_models()



def validate_inputs(prompt, item, validated):
unique_id = item
if unique_id in validated:
Expand Down Expand Up @@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
details = f"{x}, {received_type} != {type_input}"
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
error = {
"type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes",
Expand Down
119 changes: 119 additions & 0 deletions tests-unit/execution_test/validate_node_input_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pytest
from comfy_execution.validation import validate_node_input


def test_exact_match():
"""Test cases where types match exactly"""
assert validate_node_input("STRING", "STRING")
assert validate_node_input("STRING,INT", "STRING,INT")
assert validate_node_input("INT,STRING", "STRING,INT") # Order shouldn't matter


def test_strict_mode():
"""Test strict mode validation"""
# Should pass - received type is subset of input type
assert validate_node_input("STRING", "STRING,INT", strict=True)
assert validate_node_input("INT", "STRING,INT", strict=True)
assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True)

# Should fail - received type is not subset of input type
assert not validate_node_input("STRING,INT", "STRING", strict=True)
assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True)
assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True)


def test_non_strict_mode():
"""Test non-strict mode validation (default behavior)"""
# Should pass - types have overlap
assert validate_node_input("STRING,BOOLEAN", "STRING,INT")
assert validate_node_input("STRING,INT", "INT,BOOLEAN")
assert validate_node_input("STRING", "STRING,INT")

# Should fail - no overlap in types
assert not validate_node_input("BOOLEAN", "STRING,INT")
assert not validate_node_input("FLOAT", "STRING,INT")
assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT")


def test_whitespace_handling():
"""Test that whitespace is handled correctly"""
assert validate_node_input("STRING, INT", "STRING,INT")
assert validate_node_input("STRING,INT", "STRING, INT")
assert validate_node_input(" STRING , INT ", "STRING,INT")
assert validate_node_input("STRING,INT", " STRING , INT ")


def test_empty_strings():
"""Test behavior with empty strings"""
assert validate_node_input("", "")
assert not validate_node_input("STRING", "")
assert not validate_node_input("", "STRING")


def test_single_vs_multiple():
"""Test single type against multiple types"""
assert validate_node_input("STRING", "STRING,INT,BOOLEAN")
assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False)
assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True)


def test_non_string():
"""Test non-string types"""
obj1 = object()
obj2 = object()
assert validate_node_input(obj1, obj1)
assert not validate_node_input(obj1, obj2)


class NotEqualsOverrideTest(str):
"""Test class for ``__ne__`` override."""

def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if self == "LONGER_THAN_2":
return not len(value) > 2
raise TypeError("This is a class for unit tests only.")


def test_ne_override():
"""Test ``__ne__`` any override"""
any = NotEqualsOverrideTest("*")
invalid_type = "INVALID_TYPE"
obj = object()
assert validate_node_input(any, any)
assert validate_node_input(any, invalid_type)
assert validate_node_input(any, obj)
assert validate_node_input(any, {})
assert validate_node_input(any, [])
assert validate_node_input(any, [1, 2, 3])


def test_ne_custom_override():
"""Test ``__ne__`` custom override"""
special = NotEqualsOverrideTest("LONGER_THAN_2")

assert validate_node_input(special, special)
assert validate_node_input(special, "*")
assert validate_node_input(special, "INVALID_TYPE")
assert validate_node_input(special, [1, 2, 3])

# Should fail
assert not validate_node_input(special, [1, 2])
assert not validate_node_input(special, "TY")


@pytest.mark.parametrize(
"received,input_type,strict,expected",
[
("STRING", "STRING", False, True),
("STRING,INT", "STRING,INT", False, True),
("STRING", "STRING,INT", True, True),
("STRING,INT", "STRING", True, False),
("BOOLEAN", "STRING,INT", False, False),
("STRING,BOOLEAN", "STRING,INT", False, True),
],
)
def test_parametrized_cases(received, input_type, strict, expected):
"""Parametrized test cases for various scenarios"""
assert validate_node_input(received, input_type, strict) == expected

0 comments on commit 4e402b1

Please sign in to comment.