Skip to content

Commit 1839ded

Browse files
committed
polish import
1 parent 4fe9db0 commit 1839ded

File tree

2 files changed

+16
-23
lines changed

2 files changed

+16
-23
lines changed

ding/policy/qtransformer.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,18 @@
11
import copy
22
from collections import namedtuple
3-
from contextlib import nullcontext
4-
from functools import partial
5-
from pathlib import Path
6-
from typing import Any, Dict, List, Optional, Tuple, Union
3+
from typing import Any, Dict, List
74

85
import numpy as np
96
import torch
10-
import torch.distributed as dist
117
import torch.nn.functional as F
12-
from einops import pack, rearrange, repeat, unpack
13-
from einops.layers.torch import Rearrange
14-
from torch import Tensor, einsum, nn
15-
from torch.distributions import Independent, Normal
16-
from torch.nn import Module, ModuleList
17-
from torch.utils.data import DataLoader, Dataset
18-
from torchtyping import TensorType
8+
from einops import pack, rearrange
199

2010
from ding.model import model_wrap
21-
from ding.rl_utils import (get_nstep_return_data, get_train_sample,
22-
qrdqn_nstep_td_data, qrdqn_nstep_td_error,
23-
v_1step_td_data, v_1step_td_error)
2411
from ding.torch_utils import Adam, to_device
2512
from ding.utils import POLICY_REGISTRY
2613
from ding.utils.data import default_collate, default_decollate
2714

2815
from .common_utils import default_preprocess_learn
29-
from .qrdqn import QRDQNPolicy
3016
from .sac import SACPolicy
3117

3218
QIntermediates = namedtuple(

dizoo/d4rl/entry/d4rl_qtransformer_main.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,28 @@
22

33
from ding.config import read_config
44
from ding.entry import serial_pipeline_offline
5-
from ding.model.template.qtransformer import QTransformer
5+
from ding.model import QTransformer
66

77

88
def train(args):
99
# launch from anywhere
10-
config = Path(__file__).absolute().parent.parent / 'config' / args.config
10+
config = Path(__file__).absolute().parent.parent / "config" / args.config
1111
config = read_config(str(config))
12-
config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
13-
model=QTransformer(**config[0].policy.model)
14-
serial_pipeline_offline(config, seed=args.seed,model=model)
12+
config[0].exp_name = config[0].exp_name.replace("0", str(args.seed))
13+
model = QTransformer(**config[0].policy.model)
14+
serial_pipeline_offline(config, seed=args.seed, model=model)
15+
1516

1617
if __name__ == "__main__":
1718
import argparse
19+
1820
parser = argparse.ArgumentParser()
19-
parser.add_argument('--seed', '-s', type=int, default=10)
20-
parser.add_argument('--config', '-c', type=str, default='hopper_medium_expert_qtransformer_config.py')
21+
parser.add_argument("--seed", "-s", type=int, default=10)
22+
parser.add_argument(
23+
"--config",
24+
"-c",
25+
type=str,
26+
default="hopper_medium_expert_qtransformer_config.py",
27+
)
2128
args = parser.parse_args()
2229
train(args)

0 commit comments

Comments
 (0)