Skip to content

Commit 2edf46c

Browse files
committed
Pass argument in recursive call of convert_to_tensor
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent 11c0ee5 commit 2edf46c

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

monai/utils/type_conversion.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
158158
if safe:
159159
data = safe_dtype_range(data, dtype)
160160
dtype = get_equivalent_dtype(dtype, torch.Tensor)
161+
162+
# common keyword arguments for recursive calls
163+
conv_kwargs = dict(dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
164+
161165
if isinstance(data, torch.Tensor):
162166
return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
163167
if isinstance(data, np.ndarray):
@@ -172,13 +176,13 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
172176
elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))):
173177
return _convert_tensor(data, dtype=dtype, device=device)
174178
elif isinstance(data, list):
175-
list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data]
179+
list_ret = [convert_to_tensor(i, **conv_kwargs) for i in data]
176180
return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret
177181
elif isinstance(data, tuple):
178-
tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data)
182+
tuple_ret = tuple(convert_to_tensor(i, **conv_kwargs) for i in data)
179183
return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret
180184
elif isinstance(data, dict):
181-
return {k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta) for k, v in data.items()}
185+
return {k: convert_to_tensor(v, **conv_kwargs) for k, v in data.items()}
182186

183187
return data
184188

0 commit comments

Comments
 (0)