diff --git a/download_weights.sh b/download_weights.sh index 00c11b0..3b23e2e 100644 --- a/download_weights.sh +++ b/download_weights.sh @@ -1,9 +1,4 @@ mkdir weights -wget https://pjreddie.com/media/files/yolov3-tiny.weights -P weights -wget https://pjreddie.com/media/files/yolov3.weights -P weights -wget https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4.weights -P weights -wget https://github.com/AlexeyAB/darknet/releases/download/yolov4/yolov4-tiny.weights -P weights -wget https://github.com/ultralytics/yolov3/releases/download/v8/yolov3-spp.weights -P weights wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov5nu.pt -P weights wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov5su.pt -P weights wget https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov5mu.pt -P weights diff --git a/src/test.py b/src/test.py index 6245fd1..c0864a3 100644 --- a/src/test.py +++ b/src/test.py @@ -1,12 +1,19 @@ -import time +import os import torch import torchvision -import torchvision.transforms.functional as vF -import matplotlib.pyplot as plt from models import * +weight_paths = { + 'yolov3-tiny' : 'https://pjreddie.com/media/files/yolov3-tiny.weights', + 'yolov3' : 'https://pjreddie.com/media/files/yolov3.weights', + 'yolov3-spp' : 'https://github.com/ultralytics/yolov3/releases/download/v8/yolov3-spp.weights', + 'yolov4' : 'https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4.weights', + 'yolov4-tiny' : 'https://github.com/AlexeyAB/darknet/releases/download/yolov4/yolov4-tiny.weights', +} -torch.set_printoptions(5) +def download_if_not_exist(model_type: str, filepath: str): + if not os.path.exists(filepath): + torch.hub.download_url_to_file(weight_paths[model_type], filepath) def load_from_darknet(net: Union[Yolov3, Yolov3Tiny, Yolov4, Yolov4Tiny], weights_path: str): @@ -126,6 +133,8 @@ def params2(): def test(type: str, size: str = ''): + os.makedirs('../weights', exist_ok=True) + match type: case 'yolov3' : net = Yolov3(80, False).eval() case 'yolov3-spp': net = Yolov3(80, True).eval() @@ -147,7 +156,9 @@ def test(type: str, size: str = ''): has_obj = False elif 'yolov3' in type or 'yolov4' in type : - load_from_darknet(net, '../weights/{}.weights'.format(type)) + filepath = '../weights/{}.weights'.format(type) + download_if_not_exist(type, filepath) + load_from_darknet(net, filepath) elif type == 'yolov6': load_from_yolov6_official(net, "../weights/yolov6{}.pt".format(size))