
Commit 2d7c406

Reformat
1 parent b898184 commit 2d7c406

6 files changed: +50 -35 lines changed

ding/entry/tests/test_serial_entry.py (+3)

@@ -683,8 +683,11 @@ def test_discrete_dt():
         assert False, "pipeline fail"
     finally:
         os.popen('rm -rf cartpole cartpole_dt')
+
+
 test_discrete_dt()

+
 @pytest.mark.platformtest
 @pytest.mark.unittest
 def test_td3_bc():

ding/framework/middleware/data_fetcher.py (+4 -3)

@@ -23,15 +23,16 @@ def __new__(cls, *args, **kwargs):

     def __init__(self, cfg: EasyDict, dataset: Dataset):
         stream = torch.cuda.Stream()
+
         def producer(queue, dataset, batch_size, device, event):
             torch.set_num_threads(4)
             nonlocal stream
             num_gpu = dist.get_world_size()
             rank = get_rank()
             idx_list = np.random.permutation(len(dataset))
             temp_idx_list = []
-            for i in range(len(dataset)//(batch_size*num_gpu)):
-                temp_idx_list.extend(idx_list[i+rank*batch_size:i+(rank+1)*batch_size])
+            for i in range(len(dataset) // (batch_size * num_gpu)):
+                temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size])
             idx_iter = iter(temp_idx_list)

             with torch.cuda.stream(stream):
@@ -63,7 +64,7 @@ def producer(queue, dataset, batch_size, device, event):
             name='cuda_fetcher_producer'
         )

-    def __call__(self,ctx: "OfflineRLContext"):
+    def __call__(self, ctx: "OfflineRLContext"):
         if not self.producer_thread.is_alive():
             time.sleep(5)
             self.producer_thread.start()
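
For reference, the middleware above prefetches training batches on a background thread, pushing device-resident tensors into a bounded queue from a dedicated CUDA stream. A minimal standalone sketch of that producer/queue pattern, with hypothetical names such as start_prefetcher (an illustration under simple assumptions, not the DI-engine implementation):

import queue
import threading
import torch

def start_prefetcher(dataset, batch_size, device, q_size=8):
    # Background thread that keeps a bounded queue of device-resident batches.
    q = queue.Queue(maxsize=q_size)
    stream = torch.cuda.Stream() if 'cuda' in device else None

    def producer():
        for start in range(0, len(dataset) - batch_size + 1, batch_size):
            batch = torch.stack([torch.as_tensor(dataset[i]) for i in range(start, start + batch_size)])
            if stream is not None:
                with torch.cuda.stream(stream):
                    batch = batch.to(device, non_blocking=True)
            q.put(batch)  # blocks when the queue is full, throttling the producer

    t = threading.Thread(target=producer, name='prefetch_producer', daemon=True)
    t.start()
    return q, t

# usage: q, t = start_prefetcher(my_dataset, batch_size=64, device='cuda:0'); batch = q.get()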

ding/framework/middleware/functional/data_processor.py (+2 -3)

@@ -211,9 +211,7 @@ def producer(queue, dataset, batch_size, device):
     queue = Queue(maxsize=50)
     device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
     producer_thread = Thread(
-        target=producer,
-        args=(queue, dataset, cfg.policy.batch_size, device),
-        name='cuda_fetcher_producer'
+        target=producer, args=(queue, dataset, cfg.policy.batch_size, device), name='cuda_fetcher_producer'
     )

     def _fetch(ctx: "OfflineRLContext"):
@@ -263,6 +261,7 @@ def _fetch(ctx: "OfflineRLContext"):
         )
         ctx.train_data = next(dataloader)
         # TODO apply data update (e.g. priority) in offline setting when necessary
+
     return _fetch


ding/model/template/dt.py (+33 -21)

@@ -95,6 +95,7 @@ def forward(self, x):


 class DecisionTransformer(nn.Module):
+
     def __init__(
         self,
         state_dim,
@@ -121,7 +122,7 @@ def __init__(
         self.embed_ln = nn.LayerNorm(h_dim)
         self.embed_timestep = nn.Embedding(max_timestep, h_dim)
         self.drop = nn.Dropout(drop_p)
-
+
         self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim))
         self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim))

@@ -140,9 +141,11 @@ def __init__(
             # discrete actions
             self.embed_action = torch.nn.Embedding(act_dim, h_dim)
             use_action_tanh = False  # False for discrete actions
-            self.predict_action = nn.Sequential(*([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])))
+            self.predict_action = nn.Sequential(
+                *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))
+            )
         else:
-            blocks = [Block(h_dim, input_seq_len+1, n_heads, drop_p) for _ in range(n_blocks)]
+            blocks = [Block(h_dim, input_seq_len + 1, n_heads, drop_p) for _ in range(n_blocks)]
             self.state_encoder = state_encoder
             self.embed_rtg = nn.Sequential(nn.Linear(1, h_dim), nn.Tanh())
             self.head = nn.Linear(h_dim, act_dim, bias=False)
@@ -161,9 +164,8 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):

             # stack rtg, states and actions and reshape sequence as
             # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...)
-            t_p = torch.stack(
-                (returns_embeddings, state_embeddings, action_embeddings), dim=1
-            ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
+            t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings),
+                              dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
             h = self.embed_ln(t_p)
             # transformer and prediction
             h = self.transformer(h)
@@ -183,20 +185,24 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):
             state_embeddings = self.state_encoder(
                 states.reshape(-1, 4, 84, 84).type(torch.float32).contiguous()
             )  # (batch * block_size, h_dim)
-            state_embeddings = state_embeddings.reshape(
-                B, T, self.h_dim
-            )  # (batch, block_size, h_dim)
+            state_embeddings = state_embeddings.reshape(B, T, self.h_dim)  # (batch, block_size, h_dim)
             returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32))
             action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1))  # (batch, block_size, h_dim)

-            token_embeddings = torch.zeros((B, T*3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device)
-            token_embeddings[:,::3,:] = returns_embeddings
-            token_embeddings[:,1::3,:] = state_embeddings
-            token_embeddings[:,2::3,:] = action_embeddings[:,-T + int(tar is None):,:]
-
-            all_global_pos_emb = torch.repeat_interleave(self.global_pos_emb, B, dim=0) # batch_size, traj_length, h_dim
+            token_embeddings = torch.zeros(
+                (B, T * 3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device
+            )
+            token_embeddings[:, ::3, :] = returns_embeddings
+            token_embeddings[:, 1::3, :] = state_embeddings
+            token_embeddings[:, 2::3, :] = action_embeddings[:, -T + int(tar is None):, :]
+
+            all_global_pos_emb = torch.repeat_interleave(
+                self.global_pos_emb, B, dim=0
+            )  # batch_size, traj_length, h_dim

-            position_embeddings = torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)) + self.pos_emb[:, :token_embeddings.shape[1], :]
+            position_embeddings = torch.gather(
+                all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)
+            ) + self.pos_emb[:, :token_embeddings.shape[1], :]

             t_p = token_embeddings + position_embeddings

@@ -207,7 +213,7 @@ def forward(self, timesteps, states, actions, returns_to_go, tar=None):

             return_preds = None
             state_preds = None
-            action_preds = logits[:, 1::3, :] # only keep predictions from state_embeddings
+            action_preds = logits[:, 1::3, :]  # only keep predictions from state_embeddings

             return state_preds, action_preds, return_preds

@@ -227,7 +233,7 @@ def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)):
         blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
         for mn, m in self.named_modules():
             for pn, p in m.named_parameters():
-                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                 if pn.endswith('bias'):
                     # all biases will not be decayed
@@ -253,8 +259,14 @@ def configure_optimizers(self, weight_decay, learning_rate, betas=(0.9, 0.95)):

         # create the pytorch optimizer object
         optim_groups = [
-            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
-            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+            {
+                "params": [param_dict[pn] for pn in sorted(list(decay))],
+                "weight_decay": weight_decay
+            },
+            {
+                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+                "weight_decay": 0.0
+            },
         ]
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
-        return optimizer
+        return optimizer
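
The stack/permute/reshape in the forward pass above interleaves return, state and action embeddings into one (r_0, s_0, a_0, r_1, s_1, a_1, ...) token sequence, and logits[:, 1::3, :] later reads the outputs at the state positions. A minimal sketch of that interleaving with dummy tensors (example shapes only, not the model's real inputs):

import torch

B, T, H = 2, 4, 8  # batch size, context length, hidden dim (example values)
returns_emb = torch.randn(B, T, H)
state_emb = torch.randn(B, T, H)
action_emb = torch.randn(B, T, H)

# stack on a new axis, then interleave so tokens read (r_0, s_0, a_0, r_1, s_1, a_1, ...)
tokens = torch.stack((returns_emb, state_emb, action_emb), dim=1)  # (B, 3, T, H)
tokens = tokens.permute(0, 2, 1, 3).reshape(B, 3 * T, H)           # (B, 3T, H)

assert torch.equal(tokens[:, 0], returns_emb[:, 0])  # position 0 is r_0
assert torch.equal(tokens[:, 1], state_emb[:, 0])    # position 1 is s_0
assert torch.equal(tokens[:, 2], action_emb[:, 0])   # position 2 is a_0

# After the transformer, the outputs at the state positions (1, 4, 7, ...) are taken as
# action predictions, which is why the model keeps logits[:, 1::3, :].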

ding/policy/dt.py (+7 -5)

@@ -235,17 +235,19 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:

             if self.t[i] <= self.context_len:
                 if 'state_mean' not in self._cfg:
-                    timesteps[i] = min(self.t[i],
-                                       self._cfg.model.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device)
+                    timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
+                        (1, 1), dtype=torch.int64
+                    ).to(self._device)
                 else:
                     timesteps[i] = self.timesteps[i, :self.context_len]
                 states[i] = self.states[i, :self.context_len]
                 actions[i] = self.actions[i, :self.context_len]
                 rewards_to_go[i] = self.rewards_to_go[i, :self.context_len]
             else:
                 if 'state_mean' not in self._cfg:
-                    timesteps[i] = min(self.t[i],
-                                       self._cfg.model.max_timestep) * torch.ones((1, 1), dtype=torch.int64).to(self._device)
+                    timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
+                        (1, 1), dtype=torch.int64
+                    ).to(self._device)
                 else:
                     timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                 states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
@@ -267,7 +269,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
             else:
                 act = torch.argmax(logits, axis=1).unsqueeze(1)
             for i in data_id:
-                self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t
+                self.actions[i, self.t[i]] = act[i]  # TODO: self.actions[i] should be a queue when exceed max_t
                 self.t[i] += 1

             if self._cuda:
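
The wrapped calls above cap the timestep at self._cfg.model.max_timestep before building the (1, 1) int64 tensor, and the surrounding branches slice out at most context_len recent steps. A small sketch of the same clamping and context-window slicing, using hypothetical buffer names (not the policy's actual attributes):

import torch

context_len, max_timestep = 20, 1000   # example values
t = 1234                               # current environment step (hypothetical)
state_buffer = torch.randn(2000, 4)    # rolling per-env state history (hypothetical)

# keep only the last context_len steps once the episode is longer than the context
if t <= context_len:
    ctx_states = state_buffer[:context_len]
else:
    ctx_states = state_buffer[t - context_len + 1:t + 1]

# cap the timestep index at max_timestep (the positional table in dt.py is sized max_timestep + 1)
timestep = min(t, max_timestep) * torch.ones((1, 1), dtype=torch.int64)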

ding/utils/data/dataset.py (+1 -3)

@@ -151,9 +151,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
         block_size = self.context_len
         done_idx = idx + block_size
         idx = done_idx - block_size
-        states = torch.as_tensor(
-            np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32
-        ).view(block_size, -1)
+        states = torch.as_tensor(np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32).view(block_size, -1)
         actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long)
         rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32)
         timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
