Skip to content

Commit 9b171d4

Browse files
committed
Type fix
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent 2edf46c commit 9b171d4

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

monai/utils/type_conversion.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,6 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
159159
data = safe_dtype_range(data, dtype)
160160
dtype = get_equivalent_dtype(dtype, torch.Tensor)
161161

162-
# common keyword arguments for recursive calls
163-
conv_kwargs = dict(dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
164-
165162
if isinstance(data, torch.Tensor):
166163
return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
167164
if isinstance(data, np.ndarray):
@@ -176,13 +173,22 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any:
176173
elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))):
177174
return _convert_tensor(data, dtype=dtype, device=device)
178175
elif isinstance(data, list):
179-
list_ret = [convert_to_tensor(i, **conv_kwargs) for i in data]
176+
list_ret = [
177+
convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
178+
for i in data
179+
]
180180
return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret
181181
elif isinstance(data, tuple):
182-
tuple_ret = tuple(convert_to_tensor(i, **conv_kwargs) for i in data)
182+
tuple_ret = tuple(
183+
convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
184+
for i in data
185+
)
183186
return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret
184187
elif isinstance(data, dict):
185-
return {k: convert_to_tensor(v, **conv_kwargs) for k, v in data.items()}
188+
return {
189+
k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric)
190+
for k, v in data.items()
191+
}
186192

187193
return data
188194

0 commit comments

Comments
 (0)