@@ -3673,6 +3673,22 @@ def func(x):
36733673 return tf .identity (picks , name = _TFOUTPUT )
36743674 self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val })
36753675
3676+ @check_opset_min_version (10 , "IsInf" )
3677+ def test_where_with_isinf_condition (self ):
3678+ def func (x , y , z ):
3679+ # Use is_inf as condition to trigger the IsInf code path
3680+ condition = tf .math .is_inf (x )
3681+ result = tf .where (condition , y , z )
3682+ return tf .identity (result , name = _TFOUTPUT )
3683+
3684+ # Create test data with some infinite values
3685+ x_val = np .array ([1.0 , np .inf , 3.0 , - np .inf , 5.0 ], dtype = np .float32 )
3686+ y_val = np .array ([0.0 , 0.0 , 0.0 , 0.0 , 0.0 ], dtype = np .float32 )
3687+ z_val = np .array ([100.0 , 200.0 , 300.0 , 400.0 , 500.0 ], dtype = np .float32 )
3688+
3689+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : x_val , _INPUT1 : y_val , _INPUT2 : z_val })
3690+
3691+
36763692 @check_opset_min_version (9 , "IsNaN" )
36773693 def test_where_isnan (self ):
36783694 x_val = np .array ([1 , 2 , - 3 , float ('nan' ), - 5 , - 6 , float ('nan' ), 8 , 9 , 0 ], dtype = np .float32 )
0 commit comments