11from __future__ import annotations
22
33import abc
4+ import concurrent .futures
45import copy
56import logging
67import os
@@ -34,6 +35,8 @@ def __init__( # noqa: PLR0913
3435 download_timeout : int = 125 ,
3536 logger : logging .Logger | None = None ,
3637 security_reference_url : str | None = None ,
38+ skip_redownload : bool = False ,
39+ max_workers : int = 8 ,
3740 ):
3841 if not fixdater :
3942 fixdater = fixdate .default_finder (workspace )
@@ -46,6 +49,8 @@ def __init__( # noqa: PLR0913
4649 self .security_reference_url = (
4750 security_reference_url .strip ("/" ) if security_reference_url else self ._security_reference_url_
4851 )
52+ self .skip_redownload = skip_redownload
53+ self .max_workers = max_workers
4954
5055 if not logger :
5156 logger = logging .getLogger (self .__class__ .__name__ )
@@ -105,9 +110,21 @@ def __init__(# noqa: PLR0913
105110 download_timeout : int = 125 ,
106111 logger : logging .Logger | None = None ,
107112 security_reference_url : str | None = None ,
113+ skip_redownload : bool = False ,
114+ max_workers : int = 8 ,
108115 ):
109116 self ._db_filename = self ._extract_filename_from_url (url )
110- super ().__init__ (workspace , url , namespace , fixdater , download_timeout , logger , security_reference_url )
117+ super ().__init__ (
118+ workspace ,
119+ url ,
120+ namespace ,
121+ fixdater ,
122+ download_timeout ,
123+ logger ,
124+ security_reference_url ,
125+ skip_redownload = skip_redownload ,
126+ max_workers = max_workers ,
127+ )
111128
112129 def _download (self ) -> None :
113130 if not os .path .exists (self .input_dir_path ):
@@ -119,6 +136,11 @@ def _download(self) -> None:
119136 self .logger .info (f"downloading { self .namespace } secdb { self .url } " )
120137 r = http .get (self .url , self .logger , stream = True , timeout = self .download_timeout )
121138 file_path = os .path .join (self .input_dir_path , self ._db_filename )
139+ # if the file already exists and skip_redownload is True, skip writing the file again. This is to avoid
140+ # unnecessary redownloading and rewriting of the same file, which can save time on subsequent runs.
141+ if self .skip_redownload and os .path .exists (file_path ):
142+ self .logger .info (f"skipping download of { self .namespace } secdb since file already exists at { file_path } " )
143+ return
122144 with open (file_path , "wb" ) as fp :
123145 for chunk in r .iter_content ():
124146 fp .write (chunk )
@@ -236,9 +258,14 @@ class OSVParser(Parser):
236258 _input_dir_ = "osv"
237259
238260 def _download (self ) -> None :
261+ '''
262+ Download all OSV entry files based on the index file at self.url, which should point to the
263+ top level all.json file. For each entry in the index, we construct the URL for the individual
264+ entry file and download it to the input directory.
265+ '''
239266 if not os .path .exists (self .input_dir_path ):
240267 os .makedirs (self .input_dir_path , exist_ok = True )
241-
268+
242269 self .fixdater .download ()
243270
244271 try :
@@ -249,25 +276,48 @@ def _download(self) -> None:
249276 index = orjson .loads (r .content )
250277
251278 base_url = self .url .rsplit ("/" , 1 )[0 ]
252- for entry in index :
253- # for each entry pointed to by the index, pull down the full JSON file
254- filename = f"{ entry ['id' ]} .json"
255- entry_url = f"{ base_url } /{ filename } "
256- r = http .get (self .url , self .logger , stream = True , timeout = self .download_timeout )
257- file_path = os .path .join (self .input_dir_path , filename )
258- with open (file_path , "wb" ) as fp :
259- for chunk in r .iter_content ():
260- fp .write (chunk )
279+ # Download all entries in the index concurrently using a thread pool,
280+ # which should speed up the download process significantly since there are thousands of entries.
281+ # We construct the URL for each entry by appending the entry ID and .json to the base URL
282+ # e.g. https://packages.cgr.dev/chainguard/v2/osv/CGA-2255-2h2p-73q2.json
283+ with concurrent .futures .ThreadPoolExecutor (max_workers = self .max_workers ) as executor :
284+ futures = [
285+ executor .submit (self ._download_single_file , f"{ base_url } /{ entry ['id' ]} .json" , f"{ entry ['id' ]} .json" )
286+ for entry in index
287+ ]
288+ # surface the first exception (if any) — matches prior behavior where a single
289+ # failure aborted the batch via the outer try/except
290+ done , _not_done = concurrent .futures .wait (futures , return_when = concurrent .futures .FIRST_EXCEPTION )
291+ for future in done :
292+ future .result ()
261293 except Exception :
262294 self .logger .exception (f"ignoring error processing osv for { self .url } " )
263295
296+ def _download_single_file (self , url : str , filename : str ) -> None :
297+ '''
298+ Download a single OSV entry file given its URL and the desired filename.
299+ '''
300+ file_path = os .path .join (self .input_dir_path , filename )
301+ # if the file already exists and skip_redownload is True, skip writing the file again. This is to avoid
302+ # unnecessary redownloading and rewriting of the same file, which can save time on subsequent
303+ # runs.
304+ if self .skip_redownload and os .path .exists (file_path ):
305+ self .logger .info (f"skipping download of { self .namespace } osv entry { filename } since file already exists" )
306+ return
307+ self .logger .info (f"downloading { self .namespace } osv entry { filename } " )
308+ r = http .get (url , self .logger , stream = True , timeout = self .download_timeout )
309+ with open (file_path , "wb" ) as fp :
310+ for chunk in r .iter_content ():
311+ fp .write (chunk )
312+
264313 def _load (self ) -> Generator [tuple [str , dict [str , Any ]], None , None ]:
265314 try :
266315 # for each file we have downloaded, which should be every json file in the index, load it
267316 # and yield the data for normalization
268317 for filename in os .listdir (self .input_dir_path ):
269318 if not filename .endswith (".json" ):
270319 continue
320+ self .logger .info (f"loading { self .namespace } osv data from { filename } " )
271321 with open (os .path .join (self .input_dir_path , filename )) as fh :
272322 data = orjson .loads (fh .read ())
273323 yield self ._release_ , data
@@ -289,4 +339,4 @@ def _normalize(self, release: str, data: dict[str, Any]) -> dict[str, Any]: # n
289339 # we map the osv id to the osv data to keep consistency in the secdb parser, which
290340 # does this for ease of identifying the associated vulnerability when writing records.
291341 # IE: {"CGA-1234-5678-9abc": {<full osv record>}}
292- return {data ["id" ]: data }
342+ return {data ['id' ]: data }
0 commit comments