diff --git a/chapter1_overview/clip_grad_norm.py b/chapter1_overview/clip_grad_norm.py index 99e604a..5949fcd 100644 --- a/chapter1_overview/clip_grad_norm.py +++ b/chapter1_overview/clip_grad_norm.py @@ -8,8 +8,7 @@ _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] -def clip_grad_norm( - parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor: +def clip_grad_norm(parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor: """ **Overview**: Implementation of clip_grad_norm diff --git a/chapter1_overview/clip_grad_norm_zh.py b/chapter1_overview/clip_grad_norm_zh.py index 7e0e649..155e472 100644 --- a/chapter1_overview/clip_grad_norm_zh.py +++ b/chapter1_overview/clip_grad_norm_zh.py @@ -8,8 +8,7 @@ _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] -def clip_grad_norm( - parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor: +def clip_grad_norm(parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor: """ **概述**: torch.nn.utils.clip_grad_norm 的 PyTorch 版实现。 diff --git a/chapter1_overview/ppo.py b/chapter1_overview/ppo.py index 026ec8f..b58349b 100644 --- a/chapter1_overview/ppo.py +++ b/chapter1_overview/ppo.py @@ -13,7 +13,6 @@ import torch import numpy as np - ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight']) ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss']) ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac']) diff --git a/chapter1_overview/ppo_zh.py b/chapter1_overview/ppo_zh.py index 0dafbd1..3d35f18 100644 --- a/chapter1_overview/ppo_zh.py +++ b/chapter1_overview/ppo_zh.py @@ -15,7 +15,6 @@ import torch import numpy as np - ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight']) ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss']) ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac']) diff --git a/chapter2_action/continuous_tutorial.py b/chapter2_action/continuous_tutorial.py index 59946af..5529f01 100644 --- a/chapter2_action/continuous_tutorial.py +++ b/chapter2_action/continuous_tutorial.py @@ -22,6 +22,7 @@ class ContinuousPolicyNetwork(nn.Module): + def __init__(self, obs_shape: int, action_shape: int) -> None: """ **Overview**: diff --git a/chapter2_action/continuous_tutorial_zh.py b/chapter2_action/continuous_tutorial_zh.py index b4601fe..c0e8302 100644 --- a/chapter2_action/continuous_tutorial_zh.py +++ b/chapter2_action/continuous_tutorial_zh.py @@ -22,6 +22,7 @@ class ContinuousPolicyNetwork(nn.Module): + def __init__(self, obs_shape: int, action_shape: int) -> None: """ **ContinuousPolicyNetwork 定义概述**: diff --git a/chapter2_action/discrete_tutorial.py b/chapter2_action/discrete_tutorial.py index d0db9d6..c31e7f9 100644 --- a/chapter2_action/discrete_tutorial.py +++ b/chapter2_action/discrete_tutorial.py @@ -20,6 +20,7 @@ class DiscretePolicyNetwork(nn.Module): + def __init__(self, obs_shape: int, action_shape: int) -> None: """ **Overview**: diff --git a/chapter2_action/discrete_tutorial_zh.py b/chapter2_action/discrete_tutorial_zh.py index 586d787..f299525 100644 --- a/chapter2_action/discrete_tutorial_zh.py +++ b/chapter2_action/discrete_tutorial_zh.py @@ -20,6 +20,7 @@ class DiscretePolicyNetwork(nn.Module): + def __init__(self, obs_shape: int, action_shape: int) -> None: """ **DiscretePolicyNetwork 定义概述**: diff --git a/chapter2_action/hybrid_tutorial.py b/chapter2_action/hybrid_tutorial.py index 64cb420..d153f20 100644 --- a/chapter2_action/hybrid_tutorial.py +++ b/chapter2_action/hybrid_tutorial.py @@ -26,6 +26,7 @@ class HybridPolicyNetwork(nn.Module): + def __init__(self, obs_shape: int, action_shape: Dict[str, int]) -> None: """ **Overview**: @@ -76,13 +77,7 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: # $$\sigma = e^w$$ sigma = torch.exp(log_sigma) # Return treetensor-type output. - return ttorch.as_tensor({ - 'action_type': logit, - 'action_args': { - 'mu': mu, - 'sigma': sigma - } - }) + return ttorch.as_tensor({'action_type': logit, 'action_args': {'mu': mu, 'sigma': sigma}}) # delimiter diff --git a/chapter2_action/hybrid_tutorial_zh.py b/chapter2_action/hybrid_tutorial_zh.py index 339a248..02ba6a5 100644 --- a/chapter2_action/hybrid_tutorial_zh.py +++ b/chapter2_action/hybrid_tutorial_zh.py @@ -26,6 +26,7 @@ class HybridPolicyNetwork(nn.Module): + def __init__(self, obs_shape: int, action_shape: Dict[str, int]) -> None: """ **HybridPolicyNetwork 概述**: @@ -76,13 +77,7 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: # $$\sigma = e^w$$ sigma = torch.exp(log_sigma) # 返回 treetensor 类型的输出 - return ttorch.as_tensor({ - 'action_type': logit, - 'action_args': { - 'mu': mu, - 'sigma': sigma - } - }) + return ttorch.as_tensor({'action_type': logit, 'action_args': {'mu': mu, 'sigma': sigma}}) # delimiter diff --git a/chapter3_obs/encoding.py b/chapter3_obs/encoding.py index 99f9904..259af1c 100644 --- a/chapter3_obs/encoding.py +++ b/chapter3_obs/encoding.py @@ -50,7 +50,7 @@ def get_binary_encoding(bit_num: int): # Generate a matrix with shape $$2^{B} \times B $$ where B is the bit_num. # Each row with index n contains the binary representation of n. location_embedding = [] - for n in range(2**bit_num): + for n in range(2 ** bit_num): s = '0' * (bit_num - len(bin(n)[2:])) + bin(n)[2:] location_embedding.append(list(int(i) for i in s)) mat = torch.FloatTensor(location_embedding) @@ -74,12 +74,7 @@ def test_encoding(): bin_enc = get_binary_encoding(2) x = torch.arange(4) y = bin_enc(x) - ground_truth = torch.LongTensor([ - [0, 0], - [0, 1], - [1, 0], - [1, 1] - ]) + ground_truth = torch.LongTensor([[0, 0], [0, 1], [1, 0], [1, 1]]) assert torch.eq(y, ground_truth).all() diff --git a/chapter3_obs/gradient.py b/chapter3_obs/gradient.py index a49dee8..df23015 100644 --- a/chapter3_obs/gradient.py +++ b/chapter3_obs/gradient.py @@ -22,6 +22,7 @@ class LinearFunction(Function): **Overview**: Implementation of linear (Fully Connected) layer. """ + @staticmethod def forward(ctx, input_, weight, bias): """ diff --git a/chapter3_obs/mario_wrapper.py b/chapter3_obs/mario_wrapper.py index 04f2a78..50866c4 100644 --- a/chapter3_obs/mario_wrapper.py +++ b/chapter3_obs/mario_wrapper.py @@ -19,6 +19,7 @@ class OpticalFlowWrapper(gym.Wrapper): Calculate optical flow using current frame and last frame. The final output contains one channel for current frame and two channels for optical flow. """ + def __init__(self, env): """ **Overview**: