-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_finetune.sh
94 lines (89 loc) · 2.4 KB
/
train_finetune.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# ------------------- Model setting -------------------
# Positional CLI arguments:
#   $1  model name (vit_h / vit_l / anything else -> default branch)
#   $2  per-process batch size
#   $3  dataset name (cifar10 / cifar100 / imagenet_1k / imagenet_22k / custom)
#   $4  dataset root directory (forwarded as --root)
#   $5  path to pretrained checkpoint (forwarded as --pretrained)
#   $6  number of GPUs on this node (DDP world size, 1..8)
#   $7  checkpoint to resume from (forwarded as --resume)
MODEL="$1"
BATCH_SIZE="$2"
DATASET="$3"
DATASET_ROOT="$4"
PRETRAINED_MODEL="$5"
WORLD_SIZE="$6"
RESUME="$7"
# ------------------- Training setting -------------------
OPTIMIZER="adamw"
LRSCHEDULER="cosine"
MIN_LR=1e-6
WEIGHT_DECAY=0.05

# set_model_hparams <model>
# Selects per-architecture optimization hyper-parameters.
# Writes globals: MAX_EPOCH, WP_EPOCH, EVAL_EPOCH, BASE_LR, LAYER_DECAY, DROP_PATH.
# Any model name other than vit_h / vit_l takes the default branch, exactly
# like the original if/elif/else chain. A `case` is used instead of
# `[ $MODEL == ... ]`, which errored out when $MODEL was empty (unquoted
# expansion in single brackets) and relied on the non-portable `==` in `[`.
set_model_hparams() {
  case "$1" in
    vit_h)
      # Largest backbone: shorter schedule, strongest stochastic depth.
      MAX_EPOCH=50
      WP_EPOCH=5
      EVAL_EPOCH=5
      BASE_LR=0.001
      LAYER_DECAY=0.75
      DROP_PATH=0.3
      ;;
    vit_l)
      MAX_EPOCH=50
      WP_EPOCH=5
      EVAL_EPOCH=5
      BASE_LR=0.001
      LAYER_DECAY=0.75
      DROP_PATH=0.2
      ;;
    *)
      # Default branch (smaller backbones): longer schedule, lower LR,
      # weaker layer decay and drop-path.
      MAX_EPOCH=100
      WP_EPOCH=5
      EVAL_EPOCH=5
      BASE_LR=0.0005
      LAYER_DECAY=0.65
      DROP_PATH=0.1
      ;;
  esac
}
set_model_hparams "$MODEL"
# ------------------- Dataset config -------------------
# Maps the dataset name onto an input resolution, a patch size and a
# placeholder data root.
# NOTE(review): ROOT is assigned here but never read again in this script —
# the launch command below passes ${DATASET_ROOT} (argument $4) as --root.
# Confirm whether ROOT is dead or consumed by something sourcing this file.
case "$DATASET" in
  cifar10|cifar100)
    ROOT="none"
    IMG_SIZE=32
    PATCH_SIZE=2      # small images -> proportionally small patches
    ;;
  imagenet_1k|imagenet_22k)
    ROOT="path/to/imagenet"
    IMG_SIZE=224
    PATCH_SIZE=16
    ;;
  custom)
    ROOT="path/to/custom"
    IMG_SIZE=224
    PATCH_SIZE=16
    ;;
  *)
    # Diagnostics go to stderr and name the offending value so the failure
    # is actionable (original printed "Unknown dataset!!" to stdout).
    echo "Unknown dataset: '${DATASET}'" >&2
    exit 1
    ;;
esac
# ------------------- Training pipeline -------------------
# Single-node DDP launch via torch.distributed.run: one process per GPU.
# MASTER_PORT may be overridden from the environment; the default is the
# port the script originally hard-coded.
MASTER_PORT=${MASTER_PORT:-1668}
# Validate WORLD_SIZE before the arithmetic test: the original
# `(( $WORLD_SIZE >= 1 ... ))` was a shell syntax error whenever $WORLD_SIZE
# was empty or non-numeric.
if [[ "$WORLD_SIZE" =~ ^[0-9]+$ ]] && (( WORLD_SIZE >= 1 && WORLD_SIZE <= 8 )); then
  # All expansions are quoted so paths containing spaces survive intact.
  # NOTE(review): quoting --resume/--pretrained passes an explicit empty
  # string when $7/$5 are omitted; unquoted, an empty value silently dropped
  # the argument and mis-parsed the following flag — confirm main_finetune.py
  # accepts "" as "no checkpoint".
  python -m torch.distributed.run --nproc_per_node="${WORLD_SIZE}" --master_port "${MASTER_PORT}" main_finetune.py \
          --cuda \
          --distributed \
          --root "${DATASET_ROOT}" \
          --dataset "${DATASET}" \
          --model "${MODEL}" \
          --batch_size "${BATCH_SIZE}" \
          --img_size "${IMG_SIZE}" \
          --patch_size "${PATCH_SIZE}" \
          --drop_path "${DROP_PATH}" \
          --max_epoch "${MAX_EPOCH}" \
          --wp_epoch "${WP_EPOCH}" \
          --eval_epoch "${EVAL_EPOCH}" \
          --optimizer "${OPTIMIZER}" \
          --lr_scheduler "${LRSCHEDULER}" \
          --base_lr "${BASE_LR}" \
          --min_lr "${MIN_LR}" \
          --layer_decay "${LAYER_DECAY}" \
          --weight_decay "${WEIGHT_DECAY}" \
          --reprob 0.25 \
          --mixup 0.8 \
          --cutmix 1.0 \
          --resume "${RESUME}" \
          --pretrained "${PRETRAINED_MODEL}"
else
  # Original message claimed the value was "greater than 8" even when it was
  # < 1 or not a number; report the actual value, on stderr.
  echo "WORLD_SIZE must be an integer between 1 and 8 (got '${WORLD_SIZE}');" \
       "multi-machine multi-card training is currently unsupported." >&2
  exit 1
fi