@@ -158,6 +158,10 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
158158 if safe :
159159 data = safe_dtype_range (data , dtype )
160160 dtype = get_equivalent_dtype (dtype , torch .Tensor )
161+
162+ # common keyword arguments for recursive calls
163+ conv_kwargs = dict (dtype = dtype , device = device , track_meta = track_meta , convert_numeric = convert_numeric )
164+
161165 if isinstance (data , torch .Tensor ):
162166 return _convert_tensor (data ).to (dtype = dtype , device = device , memory_format = torch .contiguous_format )
163167 if isinstance (data , np .ndarray ):
@@ -172,13 +176,13 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
172176 elif (has_cp and isinstance (data , cp_ndarray )) or (convert_numeric and isinstance (data , (float , int , bool ))):
173177 return _convert_tensor (data , dtype = dtype , device = device )
174178 elif isinstance (data , list ):
175- list_ret = [convert_to_tensor (i , dtype = dtype , device = device , track_meta = track_meta ) for i in data ]
179+ list_ret = [convert_to_tensor (i , ** conv_kwargs ) for i in data ]
176180 return _convert_tensor (list_ret , dtype = dtype , device = device ) if wrap_sequence else list_ret
177181 elif isinstance (data , tuple ):
178- tuple_ret = tuple (convert_to_tensor (i , dtype = dtype , device = device , track_meta = track_meta ) for i in data )
182+ tuple_ret = tuple (convert_to_tensor (i , ** conv_kwargs ) for i in data )
179183 return _convert_tensor (tuple_ret , dtype = dtype , device = device ) if wrap_sequence else tuple_ret
180184 elif isinstance (data , dict ):
181- return {k : convert_to_tensor (v , dtype = dtype , device = device , track_meta = track_meta ) for k , v in data .items ()}
185+ return {k : convert_to_tensor (v , ** conv_kwargs ) for k , v in data .items ()}
182186
183187 return data
184188
0 commit comments