Skip to content

Commit

Permalink
Wget: add caching
Browse files Browse the repository at this point in the history
  • Loading branch information
mcgov committed Nov 12, 2024
1 parent 7098abd commit e69a23b
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions lisa/base_tools/wget.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import TYPE_CHECKING, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type
from urllib.parse import urlparse

from retry import retry
Expand All @@ -24,6 +24,10 @@ class Wget(Tool):
def command(self) -> str:
return "wget"

def _initialize(self, *args: Any, **kwargs: Any) -> None:
self.__filename_result_cache: Dict[str, str] = dict()
return super()._initialize(*args, **kwargs)

@property
def can_install(self) -> bool:
return True
Expand All @@ -45,6 +49,20 @@ def get(
force_run: bool = False,
timeout: int = 600,
) -> str:
if not force_run:
# return a cached result if one exists
try:
filename = self.__filename_result_cache[url]
return filename
except KeyError:
pass
else:
# remove the old key
try:
self.__filename_result_cache.pop(url)
except KeyError:
pass

is_valid_url(url)

if not filename:
Expand Down Expand Up @@ -88,25 +106,27 @@ def get(
f" stdout: {command_result.stdout}"
f" templog: {temp_log}"
)
self.node.tools[Rm].remove_file(log_file, sudo=sudo)
else:
download_file_path = download_path

if command_result.is_timeout:
raise LisaTimeoutException(
f"wget command is timed out after {timeout} seconds."
)
actual_file_path = self.node.execute(
ls_result = self.node.execute(
f"ls {download_file_path}",
shell=True,
sudo=sudo,
expected_exit_code=0,
expected_exit_code_failure_message="File path does not exist, "
f"{download_file_path}",
)
actual_file_path = ls_result.stdout.strip()
self.__filename_result_cache[url] = actual_file_path
if executable:
self.node.execute(f"chmod +x {actual_file_path}", sudo=sudo)
self.node.tools[Rm].remove_file(log_file, sudo=sudo)
return actual_file_path.stdout
return actual_file_path

def verify_internet_access(self) -> bool:
try:
Expand Down Expand Up @@ -159,6 +179,19 @@ def get(
force_run: bool = False,
timeout: int = 600,
) -> str:
if not force_run:
# return a cached result if one exists
try:
filename = self.__filename_result_cache[url]
return filename
except KeyError:
pass
else:
# remove the old key
try:
self.__filename_result_cache.pop(url)
except KeyError:
pass
ls = self.node.tools[Ls]

if not filename:
Expand Down Expand Up @@ -186,5 +219,5 @@ def get(
force_run=force_run,
timeout=timeout,
)

self.__filename_result_cache[url] = download_path
return download_path

0 comments on commit e69a23b

Please sign in to comment.