From 52b90b0c8b63a9ebb164c11e4ac18af9ac17532b Mon Sep 17 00:00:00 2001 From: marcojob <44396071+marcojob@users.noreply.github.com> Date: Thu, 19 Dec 2024 16:15:51 +0100 Subject: [PATCH] unittest: Add test that aims to test the inference time of the approach --- .../depth_anything_v2/model_helper.py | 59 ++++++++++--------- radarmeetsvision/utils.py | 4 +- tests/test_metric_depth_network.py | 22 +++++++ 3 files changed, 54 insertions(+), 31 deletions(-) diff --git a/radarmeetsvision/metric_depth_network/depth_anything_v2/model_helper.py b/radarmeetsvision/metric_depth_network/depth_anything_v2/model_helper.py index 02e2349..42bd1a7 100644 --- a/radarmeetsvision/metric_depth_network/depth_anything_v2/model_helper.py +++ b/radarmeetsvision/metric_depth_network/depth_anything_v2/model_helper.py @@ -20,36 +20,37 @@ def get_model(pretrained_from, use_depth_prior, encoder, max_depth, output_channels): model = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth, 'use_depth_prior': use_depth_prior, 'output_channels': output_channels}) + state_dict = model.state_dict() if pretrained_from: logger.info("Loading pretrained model") - pretrained_dict = torch.load(pretrained_from, map_location='cpu') - if 'model' in pretrained_dict.keys(): - pretrained_dict = pretrained_dict['model'] - - if use_depth_prior: - logger.info("Using depth prior") - - pretrained_weights = pretrained_dict['pretrained.patch_embed.proj.weight'] - if pretrained_weights.shape[1] < 4: - logger.info("Appending weights to pretrained network for input channels") - new_channel_weights = torch.randn(pretrained_weights.shape[0], 1, pretrained_weights.shape[2], pretrained_weights.shape[3]) * 0.01 - new_weights = torch.cat((pretrained_weights, new_channel_weights), dim=1) - pretrained_dict['pretrained.patch_embed.proj.weight'] = new_weights - - if output_channels > 1: - logger.info(f"Using {output_channels} output channels") - pretrained_weights = pretrained_dict['depth_head.scratch.output_conv2.2.weight'] - if pretrained_weights.shape[0] < output_channels: - logger.info("Appending weights to pretrained network for output channels") - new_channel_weights = torch.randn(output_channels-1, pretrained_weights.shape[1], pretrained_weights.shape[2], pretrained_weights.shape[3]) * 0.01 - new_weights = torch.cat((pretrained_weights, new_channel_weights), dim=0) - pretrained_dict['depth_head.scratch.output_conv2.2.weight'] = new_weights - - pretrained_bias_weights = pretrained_dict['depth_head.scratch.output_conv2.2.bias'] - new_bias = torch.randn(output_channels-1) * 0.01 - new_bias_weights = torch.cat((pretrained_bias_weights, new_bias), dim=0) - pretrained_dict['depth_head.scratch.output_conv2.2.bias'] = new_bias_weights - - model.load_state_dict(pretrained_dict, strict=False) + state_dict = torch.load(pretrained_from, map_location='cpu') + if 'model' in state_dict.keys(): + state_dict = state_dict['model'] + + if use_depth_prior: + logger.info("Using depth prior") + + weights = state_dict['pretrained.patch_embed.proj.weight'] + if weights.shape[1] < 4: + logger.info("Appending weights to network for input channels") + new_channel_weights = torch.randn(weights.shape[0], 1, weights.shape[2], weights.shape[3]) * 0.01 + new_weights = torch.cat((weights, new_channel_weights), dim=1) + state_dict['pretrained.patch_embed.proj.weight'] = new_weights + + if output_channels > 1: + logger.info(f"Using {output_channels} output channels") + weights = state_dict['depth_head.scratch.output_conv2.2.weight'] + if weights.shape[0] < output_channels: + logger.info("Appending weights to pretrained network for output channels") + new_channel_weights = torch.randn(output_channels-1, state_dict.shape[1], state_dict.shape[2], state_dict.shape[3]) * 0.01 + new_weights = torch.cat((state_dict, new_channel_weights), dim=0) + state_dict['depth_head.scratch.output_conv2.2.weight'] = new_weights + + weights = state_dict['depth_head.scratch.output_conv2.2.bias'] + new_bias = torch.randn(output_channels-1) * 0.01 + new_bias_weights = torch.cat((state_dict, new_bias), dim=0) + state_dict['depth_head.scratch.output_conv2.2.bias'] = new_bias_weights + + model.load_state_dict(state_dict, strict=False) return model diff --git a/radarmeetsvision/utils.py b/radarmeetsvision/utils.py index fd1029f..686f651 100644 --- a/radarmeetsvision/utils.py +++ b/radarmeetsvision/utils.py @@ -10,12 +10,12 @@ import torch from datetime import datetime -def get_device(min_memory_gb=8): +def get_device(min_memory_gb=3): device_str = 'cpu' if torch.cuda.is_available(): device = torch.cuda.get_device_properties(0) total_memory_gb = device.total_memory / (1024 ** 3) - if total_memory_gb > min_memory_gb: + if total_memory_gb >= min_memory_gb: device_str = 'cuda' return device_str diff --git a/tests/test_metric_depth_network.py b/tests/test_metric_depth_network.py index 7c58314..c943a3f 100644 --- a/tests/test_metric_depth_network.py +++ b/tests/test_metric_depth_network.py @@ -1,5 +1,6 @@ import cv2 import numpy as np +import time import unittest from .context import * @@ -25,3 +26,24 @@ def test_inference(self): # THEN: The output depth is valid self.assertFalse(np.isnan(depth).any()) + + def test_inference_time(self): + # GIVEN: A metric depth anything V2 network + model = get_model(None, True, 'vitb', 120.0, 2) + device = get_device() + print(f"Using device {device}") + model = model.to(device).eval() + + # WHEN: Random matrices are inferred + total_time = 0 + N = 10 + if device != 'cpu': + N = 500 + + for i in range(N): + img = torch.rand((1, 4, 518, 518), device=device, requires_grad=False) + start_time = time.monotonic() + prediction = model.forward(img) + total_time += (time.monotonic() - start_time) + print(total_time/float(i+1)) + print(f"Average time per iteration: {total_time/float(N)}")