Skip to content

Commit 52af1d3

Browse files
committed
parallelize osv download for efficiency
Signed-off-by: crosleyzack <mail@crosleyzack.com>
1 parent 20de888 commit 52af1d3

3 files changed

Lines changed: 95 additions & 12 deletions

File tree

src/vunnel/providers/chainguard/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ class Config:
2727
osv_url: str = "https://packages.cgr.dev/chainguard/v2/osv/all.json"
2828
# Override with VUNNEL_PROVIDERS_CHAINGUARD_USE_OSV
2929
use_osv: bool = False
30+
# Override with VUNNEL_PROVIDERS_CHAINGUARD_SKIP_REDOWNLOAD
31+
skip_redownload: bool = False
32+
# Override with VUNNEL_PROVIDERS_CHAINGUARD_OSV_MAX_WORKERS
33+
osv_max_workers: int = 8
3034

3135

3236
class Provider(provider.Provider):
@@ -54,6 +58,8 @@ def __init__(self, root: str, config: Config | None = None):
5458
namespace=self._namespace,
5559
download_timeout=self.config.request_timeout,
5660
logger=self.logger,
61+
skip_redownload=self.config.skip_redownload,
62+
max_workers=self.config.osv_max_workers,
5763
)
5864
self.schema = schema.OSVSchema(version="1.7.0")
5965
else:
@@ -63,6 +69,7 @@ def __init__(self, root: str, config: Config | None = None):
6369
namespace=self._namespace,
6470
download_timeout=self.config.request_timeout,
6571
logger=self.logger,
72+
skip_redownload=self.config.skip_redownload,
6673
)
6774
self.feed_url = self.config.secdb_url
6875
self.schema = schema.OSSchema()

src/vunnel/providers/wolfi/parser.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4+
import concurrent.futures
45
import copy
56
import logging
67
import os
@@ -34,6 +35,8 @@ def __init__( # noqa: PLR0913
3435
download_timeout: int = 125,
3536
logger: logging.Logger | None = None,
3637
security_reference_url: str | None = None,
38+
skip_redownload: bool = False,
39+
max_workers: int = 8,
3740
):
3841
if not fixdater:
3942
fixdater = fixdate.default_finder(workspace)
@@ -46,6 +49,8 @@ def __init__( # noqa: PLR0913
4649
self.security_reference_url = (
4750
security_reference_url.strip("/") if security_reference_url else self._security_reference_url_
4851
)
52+
self.skip_redownload = skip_redownload
53+
self.max_workers = max_workers
4954

5055
if not logger:
5156
logger = logging.getLogger(self.__class__.__name__)
@@ -105,9 +110,21 @@ def __init__(# noqa: PLR0913
105110
download_timeout: int = 125,
106111
logger: logging.Logger | None = None,
107112
security_reference_url: str | None = None,
113+
skip_redownload: bool = False,
114+
max_workers: int = 8,
108115
):
109116
self._db_filename = self._extract_filename_from_url(url)
110-
super().__init__(workspace, url, namespace, fixdater, download_timeout, logger, security_reference_url)
117+
super().__init__(
118+
workspace,
119+
url,
120+
namespace,
121+
fixdater,
122+
download_timeout,
123+
logger,
124+
security_reference_url,
125+
skip_redownload=skip_redownload,
126+
max_workers=max_workers,
127+
)
111128

112129
def _download(self) -> None:
113130
if not os.path.exists(self.input_dir_path):
@@ -119,6 +136,11 @@ def _download(self) -> None:
119136
self.logger.info(f"downloading {self.namespace} secdb {self.url}")
120137
r = http.get(self.url, self.logger, stream=True, timeout=self.download_timeout)
121138
file_path = os.path.join(self.input_dir_path, self._db_filename)
139+
# if the file already exists and skip_redownload is True, skip writing the file again. This is to avoid
140+
# unnecessary redownloading and rewriting of the same file, which can save time on subsequent runs.
141+
if self.skip_redownload and os.path.exists(file_path):
142+
self.logger.info(f"skipping download of {self.namespace} secdb since file already exists at {file_path}")
143+
return
122144
with open(file_path, "wb") as fp:
123145
for chunk in r.iter_content():
124146
fp.write(chunk)
@@ -236,9 +258,14 @@ class OSVParser(Parser):
236258
_input_dir_ = "osv"
237259

238260
def _download(self) -> None:
261+
'''
262+
Download all OSV entry files based on the index file at self.url, which should point to the
263+
top level all.json file. For each entry in the index, we construct the URL for the individual
264+
entry file and download it to the input directory.
265+
'''
239266
if not os.path.exists(self.input_dir_path):
240267
os.makedirs(self.input_dir_path, exist_ok=True)
241-
268+
242269
self.fixdater.download()
243270

244271
try:
@@ -249,25 +276,48 @@ def _download(self) -> None:
249276
index = orjson.loads(r.content)
250277

251278
base_url = self.url.rsplit("/", 1)[0]
252-
for entry in index:
253-
# for each entry pointed to by the index, pull down the full JSON file
254-
filename = f"{entry['id']}.json"
255-
entry_url = f"{base_url}/{filename}"
256-
r = http.get(self.url, self.logger, stream=True, timeout=self.download_timeout)
257-
file_path = os.path.join(self.input_dir_path, filename)
258-
with open(file_path, "wb") as fp:
259-
for chunk in r.iter_content():
260-
fp.write(chunk)
279+
# Download all entries in the index concurrently using a thread pool,
280+
# which should speed up the download process significantly since there are thousands of entries.
281+
# We construct the URL for each entry by appending the entry ID and .json to the base URL
282+
# e.g. https://packages.cgr.dev/chainguard/v2/osv/CGA-2255-2h2p-73q2.json
283+
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
284+
futures = [
285+
executor.submit(self._download_single_file, f"{base_url}/{entry['id']}.json", f"{entry['id']}.json")
286+
for entry in index
287+
]
288+
# surface the first exception (if any) — matches prior behavior where a single
289+
# failure aborted the batch via the outer try/except
290+
done, _not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_EXCEPTION)
291+
for future in done:
292+
future.result()
261293
except Exception:
262294
self.logger.exception(f"ignoring error processing osv for {self.url}")
263295

296+
def _download_single_file(self, url: str, filename: str) -> None:
297+
'''
298+
Download a single OSV entry file given its URL and the desired filename.
299+
'''
300+
file_path = os.path.join(self.input_dir_path, filename)
301+
# if the file already exists and skip_redownload is True, skip writing the file again. This is to avoid
302+
# unnecessary redownloading and rewriting of the same file, which can save time on subsequent
303+
# runs.
304+
if self.skip_redownload and os.path.exists(file_path):
305+
self.logger.info(f"skipping download of {self.namespace} osv entry {filename} since file already exists")
306+
return
307+
self.logger.info(f"downloading {self.namespace} osv entry {filename}")
308+
r = http.get(url, self.logger, stream=True, timeout=self.download_timeout)
309+
with open(file_path, "wb") as fp:
310+
for chunk in r.iter_content():
311+
fp.write(chunk)
312+
264313
def _load(self) -> Generator[tuple[str, dict[str, Any]], None, None]:
265314
try:
266315
# for each file we have downloaded, which should be every json file in the index, load it
267316
# and yield the data for normalization
268317
for filename in os.listdir(self.input_dir_path):
269318
if not filename.endswith(".json"):
270319
continue
320+
self.logger.info(f"loading {self.namespace} osv data from {filename}")
271321
with open(os.path.join(self.input_dir_path, filename)) as fh:
272322
data = orjson.loads(fh.read())
273323
yield self._release_, data
@@ -289,4 +339,4 @@ def _normalize(self, release: str, data: dict[str, Any]) -> dict[str, Any]: # n
289339
# we map the osv id to the osv data to keep consistency in the secdb parser, which
290340
# does this for ease of identifying the associated vulnerability when writing records.
291341
# IE: {"CGA-1234-5678-9abc": {<full osv record>}}
292-
return {data["id"]: data}
342+
return {data['id']: data}

tests/unit/providers/chainguard/test_chainguard.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,32 @@ def test_parser_selection(
3636
assert p.schema.name == expected_schema_name
3737

3838

39+
@pytest.mark.parametrize(
40+
("use_osv", "expected_parser_cls"),
41+
[
42+
(False, SecDBParser),
43+
(True, OSVParser),
44+
],
45+
)
46+
def test_config_propagates_to_parser(helpers, auto_fake_fixdate_finder, use_osv, expected_parser_cls):
47+
workspace = helpers.provider_workspace_helper(name=Provider.name())
48+
49+
c = Config(use_osv=use_osv, skip_redownload=True, osv_max_workers=16)
50+
c.runtime.result_store = result.StoreStrategy.FLAT_FILE
51+
p = Provider(root=workspace.root, config=c)
52+
53+
assert isinstance(p.parser, expected_parser_cls)
54+
assert p.parser.skip_redownload is True
55+
if use_osv:
56+
assert p.parser.max_workers == 16
57+
58+
59+
def test_config_defaults():
60+
c = Config()
61+
assert c.skip_redownload is False
62+
assert c.osv_max_workers == 8
63+
64+
3965
def test_provider_schema(helpers, disable_get_requests, auto_fake_fixdate_finder):
4066
workspace = helpers.provider_workspace_helper(name=Provider.name())
4167

0 commit comments

Comments
 (0)