Commit a1f0f73

Torch Wrappers to clean out numpy (#294)
* Pre-commit Fix
* Pre-commit Fix (2)
* Create build_deploy_latest.yml
* Create ubuntu.Dockerfile
* Update build_deploy_latest.yml
* Shift dependency to requirements.txt
* TorchWrapper; Update SAC; Update OffPolicyTrainer
* Update OnPolicyTrainer, OffPolicyTrainer, A2C, PPO, VPG, Rollout Storage and DQNs
* Delete build_deploy_latest.yml
* Delete ubuntu.Dockerfile
* Remove extra modules
* Black
* Black
* Update Noise to use torch.Tensor
* Off Policy Base: numpy --> torch.Tensor
* DQN bugs
* Noise
* OffPolicy Base and DDPG
* Noise, Parallel Vecenv wrappers, vecnormalize
* Rollout Storage PR comments
* gym.Wrapper -> genrl.environments.GymWrapper
* Pre-commit Fix
* Merge and bump version
* TD3 Fix
* LGTM errors
* Syntax error
* Bugs
* Use isinstance
* Vecenv runningmean error
* Suppress LGTM
* Remove numpy array conversion from offpolicy
* Suppress LGTM errors and update readme
* Create yml for LGTM error suppression
* LGTM yml lint; Unused Import
* YAML lint (2)
* LGTM lint (3)
* Update .lgtm.yml
* Bayesian imports
1 parent 5286095 commit a1f0f73

34 files changed (+248 -232 lines)

.lgtm.yml

+2

@@ -0,0 +1,2 @@
+queries:
+  exclude: py/import-and-import-from

.pre-commit-config.yaml

+1 -1

@@ -6,7 +6,7 @@ repos:
         args: [--exclude=^((examples|docs)/.*)$]

   - repo: https://github.com/timothycrosley/isort
-    rev: 5.4.2
+    rev: 4.3.2
     hooks:
       - id: isort

README.md

+1 -1

@@ -142,6 +142,6 @@ trainer.plot(episode_rewards)
 - [Gym](https://gym.openai.com/) - Environments
 - [Ray](https://github.com/ray-project/ray)
 - [OpenAI Baselines](https://github.com/openai/baselines) - Logger
-- [Stable Baselines 3](https://github.com/DLR-RM/stable-baselines3): Stable Baselines aims to provide _baselines_ for Deep RL Algorithms. Part of our code (e.g. Rollout Storage) is inspired from Stable Baselines.
+- [Stable Baselines 3](https://github.com/DLR-RM/stable-baselines3): Stable Baselines aims to provide _baselines_ for Deep RL Algorithms.
 - [pytorch-a2c-ppo-acktr](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail)
 - [Deep Contextual Bandits](https://github.com/tensorflow/models/tree/archive/research/deep_contextual_bandits)

genrl/agents/bandits/contextual/common/base_model.py

+2 -2

@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import Dict

-import torch
-import torch.nn as nn
+import torch  # noqa
+import torch.nn as nn  # noqa
 import torch.nn.functional as F

 from genrl.agents.bandits.contextual.common.transition import TransitionDB

genrl/agents/bandits/contextual/common/bayesian.py

+2 -2

@@ -1,7 +1,7 @@
 from typing import Dict, Optional, Tuple

-import torch
-import torch.nn as nn
+import torch  # noqa
+import torch.nn as nn  # noqa
 import torch.nn.functional as F

 from genrl.agents.bandits.contextual.common.base_model import Model

genrl/agents/bandits/contextual/common/neural.py

+2 -2

@@ -1,7 +1,7 @@
 from typing import Dict

-import torch
-import torch.nn as nn
+import torch  # noqa
+import torch.nn as nn  # noqa
 import torch.nn.functional as F

 from genrl.agents.bandits.contextual.common.base_model import Model

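Note on the `# noqa` markers above: only `F` is referenced directly in these bandit modules, but `torch` and `torch.nn` are kept imported, which linters tend to flag, and mixing `import torch` with `from torch.nn import ...` elsewhere in the PR is what the `py/import-and-import-from` LGTM query (excluded in `.lgtm.yml` above) complains about. A hypothetical sketch of the pattern, not code from this repo:

```python
# Hypothetical sketch of the import style the linters flag. The `# noqa`
# comments silence flake8-style unused/duplicate-import warnings, while the
# .lgtm.yml exclusion handles LGTM's py/import-and-import-from query.
import torch  # noqa
import torch.nn as nn  # noqa
import torch.nn.functional as F


class TinyValueHead(nn.Module):
    """Two-layer MLP head; uses torch, nn and F so none of the imports is dead."""

    def __init__(self, in_dim: int, hidden: int = 32):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.relu(self.fc1(x)))
```
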
genrl/agents/deep/a2c/a2c.py

+7 -10

@@ -1,10 +1,9 @@
 from typing import Any, Dict

 import gym
-import numpy as np
 import torch
-import torch.nn.functional as F
 import torch.optim as opt
+from torch.nn import functional as F

 from genrl.agents.deep.base import OnPolicyAgent
 from genrl.utils import get_env_properties, get_model, safe_mean

@@ -83,35 +82,33 @@ def _create_model(self) -> None:

         if self.noise is not None:
             self.noise = self.noise(
-                np.zeros_like(action_dim), self.noise_std * np.ones_like(action_dim)
+                torch.zeros(action_dim), self.noise_std * torch.ones(action_dim)
             )

         self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy)
         self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value)

     def select_action(
-        self, state: np.ndarray, deterministic: bool = False
-    ) -> np.ndarray:
+        self, state: torch.Tensor, deterministic: bool = False
+    ) -> torch.Tensor:
         """Select action given state

         Action Selection for On Policy Agents with Actor Critic

         Args:
-            state (:obj:`np.ndarray`): Current state of the environment
+            state (:obj:`torch.Tensor`): Current state of the environment
             deterministic (bool): Should the policy be deterministic or stochastic

         Returns:
-            action (:obj:`np.ndarray`): Action taken by the agent
+            action (:obj:`torch.Tensor`): Action taken by the agent
             value (:obj:`torch.Tensor`): Value of given state
             log_prob (:obj:`torch.Tensor`): Log probability of selected action
         """
-        state = torch.as_tensor(state).float().to(self.device)
-
         # create distribution based on actor output
         action, dist = self.ac.get_action(state, deterministic=deterministic)
         value = self.ac.get_value(state)

-        return action.detach().cpu().numpy(), value, dist.log_prob(action).cpu()
+        return action.detach(), value, dist.log_prob(action).cpu()

     def get_traj_loss(self, values: torch.Tensor, dones: torch.Tensor) -> None:
         """Get loss from trajectory traversed by agent during rollouts

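The practical effect of this hunk is that `select_action` becomes tensor-in/tensor-out: the vectorized env wrappers now hand the agent a `torch.Tensor` state, and the returned action is no longer round-tripped through NumPy. A hedged sketch of how a caller sees that contract, using a made-up `ToyActorCritic` rather than GenRL's actor-critic class:

```python
# Hedged sketch of the tensor-in/tensor-out contract after this change.
# ToyActorCritic is an illustrative stand-in, not GenRL's API.
import torch
import torch.nn as nn
from torch.distributions import Categorical


class ToyActorCritic(nn.Module):
    def __init__(self, state_dim: int, n_actions: int):
        super().__init__()
        self.actor = nn.Linear(state_dim, n_actions)
        self.critic = nn.Linear(state_dim, 1)

    def select_action(self, state: torch.Tensor, deterministic: bool = False):
        # state is already a tensor; no torch.as_tensor(...) conversion needed
        dist = Categorical(logits=self.actor(state))
        action = dist.probs.argmax(-1) if deterministic else dist.sample()
        value = self.critic(state)
        # return tensors directly instead of .cpu().numpy()
        return action.detach(), value, dist.log_prob(action)


ac = ToyActorCritic(state_dim=4, n_actions=2)
action, value, log_prob = ac.select_action(torch.randn(8, 4))  # batch of 8 envs
print(action.shape, value.shape, log_prob.shape)
```
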
genrl/agents/deep/base/base.py

+1

@@ -46,6 +46,7 @@ def __init__(
         self.batch_size = batch_size
         self.gamma = gamma
         self.policy_layers = policy_layers
+        self.rewards = []
         self.value_layers = value_layers
         self.lr_policy = lr_policy
         self.lr_value = lr_value

genrl/agents/deep/base/offpolicy.py

+6 -8

@@ -1,7 +1,6 @@
 import collections
 from typing import List

-import numpy as np
 import torch
 from torch.nn import functional as F

@@ -155,28 +154,27 @@ def __init__(self, *args, polyak=0.995, **kwargs):
         self.doublecritic = False

     def select_action(
-        self, state: np.ndarray, deterministic: bool = True
-    ) -> np.ndarray:
+        self, state: torch.Tensor, deterministic: bool = True
+    ) -> torch.Tensor:
         """Select action given state

         Deterministic Action Selection with Noise

         Args:
-            state (:obj:`np.ndarray`): Current state of the environment
+            state (:obj:`torch.Tensor`): Current state of the environment
             deterministic (bool): Should the policy be deterministic or stochastic

         Returns:
-            action (:obj:`np.ndarray`): Action taken by the agent
+            action (:obj:`torch.Tensor`): Action taken by the agent
         """
-        state = torch.as_tensor(state).float()
         action, _ = self.ac.get_action(state, deterministic)
-        action = action.detach().cpu().numpy()
+        action = action.detach()

         # add noise to output from policy network
         if self.noise is not None:
             action += self.noise()

-        return np.clip(
+        return torch.clamp(
             action, self.env.action_space.low[0], self.env.action_space.high[0]
         )

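With the NumPy round trip gone, the action stays a tensor all the way to the env step, and clipping switches from `np.clip` to `torch.clamp`. A small sketch of the equivalence, with invented action bounds of (-1.0, 1.0):

```python
# Minimal check that torch.clamp mirrors np.clip for action bounding.
# The bounds (-1.0, 1.0) are illustrative, not taken from a real env.
import numpy as np
import torch

action_t = torch.tensor([-2.5, 0.3, 1.7])
action_np = action_t.numpy()

clipped_t = torch.clamp(action_t, -1.0, 1.0)   # tensor path (new)
clipped_np = np.clip(action_np, -1.0, 1.0)     # numpy path (old)

assert np.allclose(clipped_t.numpy(), clipped_np)
print(clipped_t)  # tensor([-1.0000,  0.3000,  1.0000])
```
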
genrl/agents/deep/base/onpolicy.py

+5 -8

@@ -1,6 +1,3 @@
-from typing import List
-
-import numpy as np
 import torch

 from genrl.agents.deep.base import BaseAgent

@@ -48,18 +45,18 @@ def update_params(self) -> None:
         """Update parameters of the model"""
         raise NotImplementedError

-    def collect_rewards(self, dones: List[bool], timestep: int):
+    def collect_rewards(self, dones: torch.Tensor, timestep: int):
         """Helper function to collect rewards

         Runs through all the envs and collects rewards accumulated during rollouts

         Args:
-            dones (:obj:`list` of bool): Game over statuses of each environment
+            dones (:obj:`torch.Tensor`): Game over statuses of each environment
             timestep (int): Timestep during rollout
         """
         for i, done in enumerate(dones):
             if done or timestep == self.rollout_size - 1:
-                self.rewards.append(self.env.episode_reward[i])
+                self.rewards.append(self.env.episode_reward[i].detach().clone())
                 self.env.reset_single_env(i)

     def collect_rollouts(self, state: torch.Tensor):

@@ -73,12 +70,12 @@ def collect_rollouts(self, state: torch.Tensor):

         Returns:
             values (:obj:`torch.Tensor`): Values of states encountered during the rollout
-            dones (:obj:`list` of bool): Game over statuses of each environment
+            dones (:obj:`torch.Tensor`): Game over statuses of each environment
         """
         for i in range(self.rollout_size):
             action, values, old_log_probs = self.select_action(state)

-            next_state, reward, dones, _ = self.env.step(np.array(action))
+            next_state, reward, dones, _ = self.env.step(action)

             if self.render:
                 self.env.render()

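The `.detach().clone()` matters because `episode_reward[i]` is now a tensor element that the vectorized env keeps mutating and resetting in place; appending the raw element would store a view into that buffer. A toy illustration of the aliasing problem (the buffer below is invented, not GenRL's wrapper):

```python
# Why .detach().clone() is needed when logging rewards that live in a shared
# tensor buffer. episode_reward here is a toy stand-in for the env's buffer.
import torch

episode_reward = torch.tensor([3.0, 7.0])  # per-env accumulated reward

aliased = episode_reward[0]                    # view into the buffer
snapshot = episode_reward[0].detach().clone()  # independent copy

episode_reward[0] = 0.0  # env 0 gets reset after its episode ends

print(aliased)   # tensor(0.) -- silently overwritten
print(snapshot)  # tensor(3.) -- the reward we actually wanted to log
```
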
genrl/agents/deep/ddpg/ddpg.py

+2 -2

@@ -1,7 +1,7 @@
 from copy import deepcopy
 from typing import Any, Dict

-import numpy as np
+import torch
 import torch.optim as opt

 from genrl.agents import OffPolicyAgentAC

@@ -59,7 +59,7 @@ def _create_model(self) -> None:
         )
         if self.noise is not None:
             self.noise = self.noise(
-                np.zeros_like(action_dim), self.noise_std * np.ones_like(action_dim)
+                torch.zeros(action_dim), self.noise_std * torch.ones(action_dim)
             )

         if isinstance(self.network, str):

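Both A2C and DDPG now build their exploration noise from `torch.zeros`/`torch.ones` instead of `np.zeros_like`/`np.ones_like`, so the noise samples can be added directly onto tensor actions. A hedged sketch of a Gaussian action-noise helper in that style; `GaussianNoise` is an illustrative name, not GenRL's noise class:

```python
# Illustrative torch-only action noise, mirroring the torch.zeros / torch.ones
# construction above. GaussianNoise is a made-up class, not GenRL's API.
import torch


class GaussianNoise:
    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean = mean
        self.std = std

    def __call__(self) -> torch.Tensor:
        # torch.normal broadcasts mean/std and returns a tensor,
        # so the sample adds straight onto a tensor action
        return torch.normal(self.mean, self.std)


action_dim, noise_std = 2, 0.1
noise = GaussianNoise(torch.zeros(action_dim), noise_std * torch.ones(action_dim))

action = torch.tensor([0.5, -0.2])
noisy_action = torch.clamp(action + noise(), -1.0, 1.0)
print(noisy_action)
```
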
genrl/agents/deep/dqn/base.py

+16 -16

@@ -1,9 +1,10 @@
+import math
+import random
 from copy import deepcopy
 from typing import Any, Dict, List

-import numpy as np
-import torch
-import torch.optim as opt
+import torch  # noqa
+import torch.optim as opt  # noqa

 from genrl.agents import OffPolicyAgent
 from genrl.utils import get_env_properties, get_model, safe_mean

@@ -94,38 +95,37 @@ def update_params_before_select_action(self, timestep: int) -> None:
         self.epsilon = self.calculate_epsilon_by_frame()
         self.logs["epsilon"].append(self.epsilon)

-    def get_greedy_action(self, state: torch.Tensor) -> np.ndarray:
+    def get_greedy_action(self, state: torch.Tensor) -> torch.Tensor:
         """Greedy action selection

         Args:
-            state (:obj:`np.ndarray`): Current state of the environment
+            state (:obj:`torch.Tensor`): Current state of the environment

         Returns:
-            action (:obj:`np.ndarray`): Action taken by the agent
+            action (:obj:`torch.Tensor`): Action taken by the agent
         """
-        q_values = self.model(state.unsqueeze(0)).detach().numpy()
-        action = np.argmax(q_values, axis=-1).squeeze(0)
+        q_values = self.model(state.unsqueeze(0))
+        action = torch.argmax(q_values.squeeze(), dim=-1)
         return action

     def select_action(
-        self, state: np.ndarray, deterministic: bool = False
-    ) -> np.ndarray:
+        self, state: torch.Tensor, deterministic: bool = False
+    ) -> torch.Tensor:
         """Select action given state

         Epsilon-greedy action-selection

         Args:
-            state (:obj:`np.ndarray`): Current state of the environment
+            state (:obj:`torch.Tensor`): Current state of the environment
             deterministic (bool): Should the policy be deterministic or stochastic

         Returns:
-            action (:obj:`np.ndarray`): Action taken by the agent
+            action (:obj:`torch.Tensor`): Action taken by the agent
         """
-        state = torch.as_tensor(state).float()
         action = self.get_greedy_action(state)
         if not deterministic:
-            if np.random.rand() < self.epsilon:
-                action = np.asarray(self.env.sample())
+            if random.random() < self.epsilon:
+                action = self.env.sample()
         return action

     def _reshape_batch(self, batch: List):

@@ -208,7 +208,7 @@ def calculate_epsilon_by_frame(self) -> float:
         Exponentially decays exploration rate from max epsilon to min epsilon
         The greater the value of epsilon_decay, the slower the decrease in epsilon
         """
-        return self.min_epsilon + (self.max_epsilon - self.min_epsilon) * np.exp(
+        return self.min_epsilon + (self.max_epsilon - self.min_epsilon) * math.exp(
             -1.0 * self.timestep / self.epsilon_decay
         )

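The scalar math here no longer needs NumPy: `math.exp` handles the epsilon decay and `random.random()` the exploration coin flip, while the greedy branch stays in torch. A standalone sketch of the schedule, using made-up hyperparameters rather than GenRL's defaults:

```python
# Sketch of the epsilon schedule with math.exp instead of np.exp.
# max_epsilon, min_epsilon and epsilon_decay are illustrative values.
import math
import random

max_epsilon, min_epsilon, epsilon_decay = 1.0, 0.01, 1000


def epsilon_by_frame(timestep: int) -> float:
    """Exponentially decay epsilon from max_epsilon towards min_epsilon."""
    return min_epsilon + (max_epsilon - min_epsilon) * math.exp(
        -1.0 * timestep / epsilon_decay
    )


for t in (0, 500, 5000):
    eps = epsilon_by_frame(t)
    explore = random.random() < eps  # epsilon-greedy coin flip
    print(f"t={t:5d} epsilon={eps:.3f} explore={explore}")
```
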
genrl/agents/deep/dqn/categorical.py

+5 -6

@@ -1,7 +1,6 @@
 import collections
-from typing import List, Tuple
+from typing import Tuple

-import numpy as np
 import torch

 from genrl.agents.deep.dqn.base import DQN

@@ -67,14 +66,14 @@ def __init__(
         if self.create_model:
             self._create_model(noisy_layers=self.noisy_layers, num_atoms=self.num_atoms)

-    def get_greedy_action(self, state: torch.Tensor) -> np.ndarray:
+    def get_greedy_action(self, state: torch.Tensor) -> torch.Tensor:
         """Greedy action selection

         Args:
-            state (:obj:`np.ndarray`): Current state of the environment
+            state (:obj:`torch.Tensor`): Current state of the environment

         Returns:
-            action (:obj:`np.ndarray`): Action taken by the agent
+            action (:obj:`torch.Tensor`): Action taken by the agent
         """
         return categorical_greedy_action(self, state)

@@ -91,7 +90,7 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor):
         return categorical_q_values(self, states, actions)

     def get_target_q_values(
-        self, next_states: np.ndarray, rewards: List[float], dones: List[bool]
+        self, next_states: torch.Tensor, rewards: torch.Tensor, dones: torch.Tensor
     ):
         """Projected Distribution of Q-values

genrl/agents/deep/dqn/utils.py

+15 -13

@@ -1,7 +1,5 @@
 import collections
-from typing import List

-import numpy as np
 import torch

 from genrl.agents.deep.dqn.base import DQN

@@ -64,25 +62,27 @@ def prioritized_q_loss(agent: DQN, batch: collections.namedtuple):
     return loss


-def categorical_greedy_action(agent: DQN, state: torch.Tensor) -> np.ndarray:
+def categorical_greedy_action(agent: DQN, state: torch.Tensor) -> torch.Tensor:
     """Greedy action selection for Categorical DQN

     Args:
         agent (:obj:`DQN`): The agent
-        state (:obj:`np.ndarray`): Current state of the environment
+        state (:obj:`torch.Tensor`): Current state of the environment

     Returns:
-        action (:obj:`np.ndarray`): Action taken by the agent
+        action (:obj:`torch.Tensor`): Action taken by the agent
     """
-    q_value_dist = agent.model(state.unsqueeze(0)).detach().numpy()
+    q_value_dist = agent.model(state.unsqueeze(0)).detach()  # .numpy()
     # We need to scale and discretise the Q-value distribution obtained above
-    q_value_dist = q_value_dist * np.linspace(agent.v_min, agent.v_max, agent.num_atoms)
+    q_value_dist = q_value_dist * torch.linspace(
+        agent.v_min, agent.v_max, agent.num_atoms
+    )
     # Then we find the action with the highest Q-values for all discrete regions
     # Current shape of the q_value_dist is [1, n_envs, action_dim, num_atoms]
     # So we take the sum of all the individual atom q_values and then take argmax
     # along action dim to get the optimal action. Since batch_size is 1 for this
     # function, we squeeze the first dimension out.
-    action = np.argmax(q_value_dist.sum(-1), axis=-1).squeeze(0)
+    action = torch.argmax(q_value_dist.sum(-1), axis=-1).squeeze(0)
     return action

@@ -119,9 +119,9 @@ def categorical_q_values(agent: DQN, states: torch.Tensor, actions: torch.Tensor):

 def categorical_q_target(
     agent: DQN,
-    next_states: np.ndarray,
-    rewards: List[float],
-    dones: List[bool],
+    next_states: torch.Tensor,
+    rewards: torch.Tensor,
+    dones: torch.Tensor,
 ):
     """Projected Distribution of Q-values

@@ -140,8 +140,10 @@ def categorical_q_target(
     support = torch.linspace(agent.v_min, agent.v_max, agent.num_atoms)

     next_q_value_dist = agent.target_model(next_states) * support
-    next_actions = torch.argmax(next_q_value_dist.sum(-1), axis=-1)
-    next_actions = next_actions[:, :, np.newaxis, np.newaxis]
+    next_actions = (
+        torch.argmax(next_q_value_dist.sum(-1), axis=-1).unsqueeze(-1).unsqueeze(-1)
+    )
+
     next_actions = next_actions.expand(
         agent.batch_size, agent.env.n_envs, 1, agent.num_atoms
     )