Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed JHMDB weight-loading issue (#6)

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion configuration/TubeR_CSN152_AVA21.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ CONFIG:
LABEL_PATH: '/xxx/datasets/ava_action_list_v2.1_for_activitynet_2018.pbtxt'
ANNO_PATH: '/xxx/datasets/ava_{}_v21.json'
DATA_PATH: '/xxx/ava/frames/{}/'
EXCLUDE_PATH: '/xxx/datasets/ava_val_excluded_timestamps_v2.1.csv'
NUM_CLASSES: 80
MULTIGRID: False
IMG_SIZE: 256
Expand All @@ -68,7 +69,7 @@ CONFIG:
DS_RATE: 8
TEMP_LEN: 32
SAMPLE_RATE: 2
PRETRAINED: False
PRETRAINED: True
PRETRAIN_BACKBONE_DIR: "/xxx/irCSN_152_ft_kinetics_from_ig65m_f126851907.mat"
PRETRAIN_TRANSFORMER_DIR: "/xxx/detr.pth"
PRETRAINED_PATH: "/xxx/ADTR_CSN_152_ava_21.pth"
Expand Down
3 changes: 2 additions & 1 deletion configuration/TubeR_CSN152_AVA22.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ CONFIG:
LABEL_PATH: '/xxx/datasets/ava_action_list_v2.1_for_activitynet_2018.pbtxt'
ANNO_PATH: '/xxx/datasets/ava_{}_v22.json'
DATA_PATH: '/xxx/ava/frames/{}/'
EXCLUDE_PATH: '/xxx/datasets/ava_val_excluded_timestamps_v2.1.csv'
NUM_CLASSES: 80
MULTIGRID: False
IMG_SIZE: 256
Expand All @@ -68,7 +69,7 @@ CONFIG:
DS_RATE: 8
TEMP_LEN: 32
SAMPLE_RATE: 2
PRETRAINED: False
PRETRAINED: True
PRETRAIN_BACKBONE_DIR: "/xxx/irCSN_152_ft_kinetics_from_ig65m_f126851907.mat"
PRETRAIN_TRANSFORMER_DIR: "/xxx/detr.pth"
PRETRAINED_PATH: "/xxx/ADTR_CSN_152_decode_ava_22.pth"
Expand Down
3 changes: 2 additions & 1 deletion configuration/TubeR_CSN50_AVA21.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ CONFIG:
LABEL_PATH: '/xxx/datasets/ava_action_list_v2.1_for_activitynet_2018.pbtxt'
ANNO_PATH: '/xxx/datasets/ava_{}_v21.json'
DATA_PATH: '/xxx/ava/frames/{}/'
EXCLUDE_PATH: '/xxx/datasets/ava_val_excluded_timestamps_v2.1.csv'
NUM_CLASSES: 80
MULTIGRID: False
IMG_SIZE: 256
Expand All @@ -68,7 +69,7 @@ CONFIG:
DS_RATE: 8
TEMP_LEN: 32
SAMPLE_RATE: 2
PRETRAINED: False
PRETRAINED: True
PRETRAIN_BACKBONE_DIR: "/xxx/irCSN_152_ft_kinetics_from_ig65m_f126851907.mat"
PRETRAIN_TRANSFORMER_DIR: "/xxx/detr.pth"
PRETRAINED_PATH: "/xxx/ADTR_CSN_50_decode_ava_21.pth"
Expand Down
4 changes: 2 additions & 2 deletions configuration/Tuber_CSN152_JHMDB.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ CONFIG:
FRAME_RATE: 2

MODEL:
SINGLE_FRAME: True
SINGLE_FRAME: False
BACKBONE_NAME: CSN-152
TEMPORAL_DS_STRATEGY: decoder
LAST_STRIDE: False
Expand All @@ -69,7 +69,7 @@ CONFIG:
DS_RATE: 8
TEMP_LEN: 32
SAMPLE_RATE: 2
PRETRAINED: False
PRETRAINED: True
PRETRAIN_BACKBONE_DIR: "xxx"
PRETRAIN_TRANSFORMER_DIR: "xxx"
PRETRAINED_PATH: "/xxx/ADTR_CSN_152_jhmdb.pth"
Expand Down
6 changes: 5 additions & 1 deletion datasets/jhmdb_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import datasets.video_transforms as T
from utils.misc import collate_fn
from glob import glob
import random


# Assisting function for finding a good/bad tubelet
Expand Down Expand Up @@ -98,7 +99,10 @@ def check_video(self, vid):

def __getitem__(self, index):
sample_id, frame_id = self.index_to_sample_t[index]
p_t = self.clip_len // 2
if self.mode == 'train':
p_t = random.randint(1, self.clip_len - 2)
else:
p_t = self.clip_len // 2

target = self.load_annotation(sample_id, frame_id, index, p_t)
imgs = self.loadvideo(frame_id, sample_id, target, p_t)
Expand Down
4 changes: 2 additions & 2 deletions evaluates/evaluate_ava.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ class STDetectionEvaluater(object):
evaluate(): run evaluation code
'''

def __init__(self, label_path, tiou_thresholds=[0.5], load_from_dataset=False, class_num=60):
def __init__(self, label_path, exclude_path=None, tiou_thresholds=[0.5], load_from_dataset=False, class_num=60):
self.label_path = label_path
categories, class_whitelist = read_labelmap(self.label_path)
self.class_num = class_num
if class_num == 80:
if class_num == 80 and exclude_path is not None:
self.exclude_keys = []
f = open("/xxx/datasets/ava_val_excluded_timestamps_v2.1.csv")
while True:
Expand Down
1 change: 0 additions & 1 deletion train_tuber_jhmdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def main_worker(cfg):
for epoch in range(cfg.CONFIG.TRAIN.START_EPOCH, cfg.CONFIG.TRAIN.EPOCH_NUM):
if cfg.DDP_CONFIG.DISTRIBUTED:
train_sampler.set_epoch(epoch)
time.sleep(1000)
train_tuber_detection(cfg, model, criterion, train_loader, optimizer, epoch, cfg.CONFIG.LOSS_COFS.CLIPS_MAX_NORM, lr_scheduler, writer)
validate_tuber_ucf_detection(cfg, model, criterion, postprocessors, val_loader, 0, writer)

Expand Down
8 changes: 5 additions & 3 deletions utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ def load_detr_weights(model, pretrain_dir, cfg):
elif k.split('.')[1] == 'query_embed':
if not cfg.CONFIG.MODEL.SINGLE_FRAME:
query_size = cfg.CONFIG.MODEL.QUERY_NUM * (cfg.CONFIG.MODEL.TEMP_LEN // cfg.CONFIG.MODEL.DS_RATE)
pretrained_dict.update({k: v[:query_size].repeat(cfg.CONFIG.MODEL.DS_RATE, 1)})
else:
query_size = cfg.CONFIG.MODEL.QUERY_NUM
pretrained_dict.update({k: v[:query_size]})
pretrained_dict.update({k: v[:query_size]})

pretrained_dict_ = {k: v for k, v in pretrained_dict.items() if k in model_dict}
unused_dict = {k: v for k, v in pretrained_dict.items() if not k in model_dict}
Expand Down Expand Up @@ -57,8 +58,9 @@ def deploy_model(model, cfg, is_tuber=True):
# DataParallel will divide and allocate batch_size to all available GPUs
model = torch.nn.DataParallel(model).cuda()

print("loading detr")
load_detr_weights(model, cfg.CONFIG.MODEL.PRETRAIN_TRANSFORMER_DIR, cfg)
if cfg.CONFIG.MODEL.PRETRAINED:
print("loading detr")
load_detr_weights(model, cfg.CONFIG.MODEL.PRETRAIN_TRANSFORMER_DIR, cfg)

return model

Expand Down
14 changes: 2 additions & 12 deletions utils/video_action_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def validate_tuber_detection(cfg, model, criterion, postprocessors, data_loader,
# aggregate files
if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0:
# read results
evaluater = STDetectionEvaluater(cfg.CONFIG.DATA.LABEL_PATH, class_num=cfg.CONFIG.DATA.NUM_CLASSES)
evaluater = STDetectionEvaluater(cfg.CONFIG.DATA.LABEL_PATH, class_num=cfg.CONFIG.DATA.NUM_CLASSES, exclude_path=cfg.CONFIG.DATA.EXCLUDE_PATH)
file_path_lst = [tmp_GT_path.format(cfg.CONFIG.LOG.BASE_PATH, cfg.CONFIG.LOG.RES_DIR, x) for x in range(cfg.DDP_CONFIG.GPU_WORLD_SIZE)]
evaluater.load_GT_from_path(file_path_lst)
file_path_lst = [tmp_path.format(cfg.CONFIG.LOG.BASE_PATH, cfg.CONFIG.LOG.RES_DIR, x) for x in range(cfg.DDP_CONFIG.GPU_WORLD_SIZE)]
Expand Down Expand Up @@ -562,7 +562,7 @@ def validate_tuber_ucf_detection(cfg, model, criterion, postprocessors, data_loa
buff_binary.append(output_b[..., 0])

val_label = targets[bidx]["labels"]
val_category = torch.full((len(val_label), 21), 0)
val_category = torch.full((len(val_label), cfg.CONFIG.DATA.NUM_CLASSES), 0)
for vl in range(len(val_label)):
label = int(val_label[vl])
val_category[vl, label] = 1
Expand Down Expand Up @@ -675,15 +675,5 @@ def validate_tuber_ucf_detection(cfg, model, criterion, postprocessors, data_loa
writer.add_scalar('val/val_mAP_epoch', mAP[0], epoch)
Map_ = mAP[0]

# evaluater = STDetectionEvaluaterSinglePerson(cfg.CONFIG.DATA.LABEL_PATH)
# file_path_lst = [tmp_GT_path.format(cfg.CONFIG.LOG.BASE_PATH, cfg.CONFIG.LOG.RES_DIR, x) for x in range(cfg.DDP_CONFIG.GPU_WORLD_SIZE)]
# evaluater.load_GT_from_path(file_path_lst)
# file_path_lst = [tmp_path.format(cfg.CONFIG.LOG.BASE_PATH, cfg.CONFIG.LOG.RES_DIR, x) for x in range(cfg.DDP_CONFIG.GPU_WORLD_SIZE)]
# evaluater.load_detection_from_path(file_path_lst)
# mAP, metrics = evaluater.evaluate()
# print(metrics)
# print_string = 'person AP: {mAP:.5f}'.format(mAP=mAP[0])
# print(print_string)
# writer.add_scalar('val/val_person_AP_epoch', mAP[0], epoch)
torch.distributed.barrier()
return Map_