Skip to content

Commit

Permalink
polish(nyz): polish code example yapf style
Browse files (browse the repository at this point in the history)
  • Loading branch information
PaParaZz1 committed Mar 1, 2023
1 parent 73835bb commit efab693
Show file tree
Hide file tree
Showing 13 changed files with 14 additions and 27 deletions.
3 changes: 1 addition & 2 deletions chapter1_overview/clip_grad_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def clip_grad_norm(
parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
def clip_grad_norm(parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
"""
**Overview**:
Implementation of clip_grad_norm <link https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_ link>
Expand Down
3 changes: 1 addition & 2 deletions chapter1_overview/clip_grad_norm_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def clip_grad_norm(
parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
def clip_grad_norm(parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
"""
**概述**:
torch.nn.utils.clip_grad_norm 的 PyTorch 版实现。<link https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_ link>
Expand Down
1 change: 0 additions & 1 deletion chapter1_overview/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch
import numpy as np


ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])
ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])
Expand Down
1 change: 0 additions & 1 deletion chapter1_overview/ppo_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import torch
import numpy as np


ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])
ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])
Expand Down
1 change: 1 addition & 0 deletions chapter2_action/continuous_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@


class ContinuousPolicyNetwork(nn.Module):

def __init__(self, obs_shape: int, action_shape: int) -> None:
"""
**Overview**:
Expand Down
1 change: 1 addition & 0 deletions chapter2_action/continuous_tutorial_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@


class ContinuousPolicyNetwork(nn.Module):

def __init__(self, obs_shape: int, action_shape: int) -> None:
"""
**ContinuousPolicyNetwork 定义概述**:
Expand Down
1 change: 1 addition & 0 deletions chapter2_action/discrete_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


class DiscretePolicyNetwork(nn.Module):

def __init__(self, obs_shape: int, action_shape: int) -> None:
"""
**Overview**:
Expand Down
1 change: 1 addition & 0 deletions chapter2_action/discrete_tutorial_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


class DiscretePolicyNetwork(nn.Module):

def __init__(self, obs_shape: int, action_shape: int) -> None:
"""
**DiscretePolicyNetwork 定义概述**:
Expand Down
9 changes: 2 additions & 7 deletions chapter2_action/hybrid_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


class HybridPolicyNetwork(nn.Module):

def __init__(self, obs_shape: int, action_shape: Dict[str, int]) -> None:
"""
**Overview**:
Expand Down Expand Up @@ -76,13 +77,7 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
# $$\sigma = e^w$$
sigma = torch.exp(log_sigma)
# Return treetensor-type output.
return ttorch.as_tensor({
'action_type': logit,
'action_args': {
'mu': mu,
'sigma': sigma
}
})
return ttorch.as_tensor({'action_type': logit, 'action_args': {'mu': mu, 'sigma': sigma}})


# delimiter
Expand Down
9 changes: 2 additions & 7 deletions chapter2_action/hybrid_tutorial_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


class HybridPolicyNetwork(nn.Module):

def __init__(self, obs_shape: int, action_shape: Dict[str, int]) -> None:
"""
**HybridPolicyNetwork 概述**:
Expand Down Expand Up @@ -76,13 +77,7 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
# $$\sigma = e^w$$
sigma = torch.exp(log_sigma)
# 返回 treetensor 类型的输出
return ttorch.as_tensor({
'action_type': logit,
'action_args': {
'mu': mu,
'sigma': sigma
}
})
return ttorch.as_tensor({'action_type': logit, 'action_args': {'mu': mu, 'sigma': sigma}})


# delimiter
Expand Down
9 changes: 2 additions & 7 deletions chapter3_obs/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_binary_encoding(bit_num: int):
# Generate a matrix with shape $$2^{B} \times B $$ where B is the bit_num.
# Each row with index n contains the binary representation of n.
location_embedding = []
for n in range(2**bit_num):
for n in range(2 ** bit_num):
s = '0' * (bit_num - len(bin(n)[2:])) + bin(n)[2:]
location_embedding.append(list(int(i) for i in s))
mat = torch.FloatTensor(location_embedding)
Expand All @@ -74,12 +74,7 @@ def test_encoding():
bin_enc = get_binary_encoding(2)
x = torch.arange(4)
y = bin_enc(x)
ground_truth = torch.LongTensor([
[0, 0],
[0, 1],
[1, 0],
[1, 1]
])
ground_truth = torch.LongTensor([[0, 0], [0, 1], [1, 0], [1, 1]])
assert torch.eq(y, ground_truth).all()


Expand Down
1 change: 1 addition & 0 deletions chapter3_obs/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class LinearFunction(Function):
**Overview**:
Implementation of linear (Fully Connected) layer.
"""

@staticmethod
def forward(ctx, input_, weight, bias):
"""
Expand Down
1 change: 1 addition & 0 deletions chapter3_obs/mario_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class OpticalFlowWrapper(gym.Wrapper):
Calculate optical flow using current frame and last frame. The final output contains one channel for current frame and two channels for optical flow.
<link https://en.wikipedia.org/wiki/Optical_flow link>
"""

def __init__(self, env):
"""
**Overview**:
Expand Down

0 comments on commit efab693

Please sign in to comment.