diff --git a/tutorials/01-basics/pytorch_basics/main.py b/tutorials/01-basics/pytorch_basics/main.py index 744400c2..c475feda 100644 --- a/tutorials/01-basics/pytorch_basics/main.py +++ b/tutorials/01-basics/pytorch_basics/main.py @@ -93,9 +93,11 @@ # Convert the numpy array to a torch tensor. y = torch.from_numpy(x) -# Convert the torch tensor to a numpy array. +# Convert the torch tensor back to a numpy array. z = y.numpy() +# x and z have identical values inside +assert np.all(x == z) # ================================================================== # # 4. Input pipeline #