From 92d4013e70bb9f4684c29856317ee69509dd9920 Mon Sep 17 00:00:00 2001
From: takuseno
Date: Sun, 3 Nov 2024 16:36:09 +0900
Subject: [PATCH] Add tests

---
 tests/test_torch_utility.py | 56 +++++++++++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py
index 9c5d0b70..d6412682 100644
--- a/tests/test_torch_utility.py
+++ b/tests/test_torch_utility.py
@@ -18,6 +18,7 @@
     TorchMiniBatch,
     TorchTrajectoryMiniBatch,
     View,
+    copy_recursively,
     eval_api,
     get_batch_size,
     get_device,
@@ -168,6 +169,19 @@ def test_to_cpu() -> None:
     pass
 
 
+def test_copy_recursively() -> None:
+    x = torch.rand(10)
+    y = torch.rand(10)
+    copy_recursively(x, y)
+    assert torch.all(x == y)
+
+    x_list = [torch.rand(10), torch.rand(20)]
+    y_list = [torch.rand(10), torch.rand(20)]
+    copy_recursively(x_list, y_list)
+    assert torch.all(x_list[0] == y_list[0])
+    assert torch.all(x_list[1] == y_list[1])
+
+
 def test_get_device() -> None:
     x = torch.rand(10)
     assert get_device(x) == "cpu"
@@ -323,6 +337,29 @@ def test_torch_mini_batch(
     assert np.all(torch_batch.terminals.numpy() == batch.terminals)
     assert np.all(torch_batch.intervals.numpy() == batch.intervals)
 
+    torch_batch2 = TorchMiniBatch(
+        observations=torch.zeros_like(torch_batch.observations),
+        actions=torch.zeros_like(torch_batch.actions),
+        rewards=torch.zeros_like(torch_batch.rewards),
+        next_observations=torch.zeros_like(torch_batch.next_observations),
+        next_actions=torch.zeros_like(torch_batch.next_actions),
+        returns_to_go=torch.zeros_like(torch_batch.returns_to_go),
+        terminals=torch.zeros_like(torch_batch.terminals),
+        intervals=torch.zeros_like(torch_batch.intervals),
+        device=torch_batch.device,
+    )
+    torch_batch2.copy_(torch_batch)
+    assert torch.all(torch_batch2.observations == torch_batch.observations)
+    assert torch.all(torch_batch2.actions == torch_batch.actions)
+    assert torch.all(torch_batch2.rewards == torch_batch.rewards)
+    assert torch.all(
+        torch_batch2.next_observations == torch_batch.next_observations
+    )
+    assert torch.all(torch_batch2.next_actions == torch_batch.next_actions)
+    assert torch.all(torch_batch2.returns_to_go == torch_batch.returns_to_go)
+    assert torch.all(torch_batch2.terminals == torch_batch.terminals)
+    assert torch.all(torch_batch2.intervals == torch_batch.intervals)
+
 
 @pytest.mark.parametrize("batch_size", [32])
 @pytest.mark.parametrize("length", [32])
@@ -397,6 +434,25 @@ def test_torch_trajectory_mini_batch(
 
     assert np.all(torch_batch.terminals.numpy() == batch.terminals)
 
+    torch_batch2 = TorchTrajectoryMiniBatch(
+        observations=torch.zeros_like(torch_batch.observations),
+        actions=torch.zeros_like(torch_batch.actions),
+        rewards=torch.zeros_like(torch_batch.rewards),
+        returns_to_go=torch.zeros_like(torch_batch.returns_to_go),
+        terminals=torch.zeros_like(torch_batch.terminals),
+        timesteps=torch.zeros_like(torch_batch.timesteps),
+        masks=torch.zeros_like(torch_batch.masks),
+        device=torch_batch.device,
+    )
+    torch_batch2.copy_(torch_batch)
+    assert torch.all(torch_batch2.observations == torch_batch.observations)
+    assert torch.all(torch_batch2.actions == torch_batch.actions)
+    assert torch.all(torch_batch2.rewards == torch_batch.rewards)
+    assert torch.all(torch_batch2.returns_to_go == torch_batch.returns_to_go)
+    assert torch.all(torch_batch2.terminals == torch_batch.terminals)
+    assert torch.all(torch_batch2.timesteps == torch_batch.timesteps)
+    assert torch.all(torch_batch2.masks == torch_batch.masks)
+
 
 def test_checkpointer() -> None:
     fc1 = torch.nn.Linear(100, 100)