@@ -159,9 +159,6 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
159159 data = safe_dtype_range (data , dtype )
160160 dtype = get_equivalent_dtype (dtype , torch .Tensor )
161161
162- # common keyword arguments for recursive calls
163- conv_kwargs = dict (dtype = dtype , device = device , track_meta = track_meta , convert_numeric = convert_numeric )
164-
165162 if isinstance (data , torch .Tensor ):
166163 return _convert_tensor (data ).to (dtype = dtype , device = device , memory_format = torch .contiguous_format )
167164 if isinstance (data , np .ndarray ):
@@ -176,13 +173,22 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
176173 elif (has_cp and isinstance (data , cp_ndarray )) or (convert_numeric and isinstance (data , (float , int , bool ))):
177174 return _convert_tensor (data , dtype = dtype , device = device )
178175 elif isinstance (data , list ):
179- list_ret = [convert_to_tensor (i , ** conv_kwargs ) for i in data ]
176+ list_ret = [
177+ convert_to_tensor (i , dtype = dtype , device = device , track_meta = track_meta , convert_numeric = convert_numeric )
178+ for i in data
179+ ]
180180 return _convert_tensor (list_ret , dtype = dtype , device = device ) if wrap_sequence else list_ret
181181 elif isinstance (data , tuple ):
182- tuple_ret = tuple (convert_to_tensor (i , ** conv_kwargs ) for i in data )
182+ tuple_ret = tuple (
183+ convert_to_tensor (i , dtype = dtype , device = device , track_meta = track_meta , convert_numeric = convert_numeric )
184+ for i in data
185+ )
183186 return _convert_tensor (tuple_ret , dtype = dtype , device = device ) if wrap_sequence else tuple_ret
184187 elif isinstance (data , dict ):
185- return {k : convert_to_tensor (v , ** conv_kwargs ) for k , v in data .items ()}
188+ return {
189+ k : convert_to_tensor (v , dtype = dtype , device = device , track_meta = track_meta , convert_numeric = convert_numeric )
190+ for k , v in data .items ()
191+ }
186192
187193 return data
188194
0 commit comments