diff --git a/cloudcheck/providers/__init__.py b/cloudcheck/providers/__init__.py
index f4f8eac..43dc94a 100644
--- a/cloudcheck/providers/__init__.py
+++ b/cloudcheck/providers/__init__.py
@@ -1,9 +1,11 @@
 import json
 import httpx
+import pickle
 import asyncio
 import logging
 import importlib
 from pathlib import Path
+from contextlib import suppress
 from datetime import datetime, timedelta
 
 from .base import BaseCloudProvider
@@ -35,33 +37,55 @@ class CloudProviders:
     json_path = code_path / "cloud_providers.json"
     cache_path = code_path / ".cloudcheck_cache"
 
-    def __init__(self, httpx_client=None):
+    def __init__(self):
         self.providers = {}
-        self._httpx_client = httpx_client
         self.cache_key = None
         self.load_from_json()
 
-    def load_from_json(self):
-        if self.json_path.is_file():
-            with open(self.json_path) as f:
-                try:
-                    j = json.load(f)
-                    for k in list(j):
-                        j[k.lower()] = j.pop(k)
-                except Exception as e:
-                    log.warning(f"Failed to parsed JSON at {self.json_path}: {e}")
-                    return
+    def load_from_json(self, force=False):
+        # loading from a pickled cache is about 1 second faster than loading from JSON
+        if (not force) and self.cache_path.is_file():
+            self.load_from_cache()
+        else:
+            if self.json_path.is_file():
+                with open(self.json_path) as f:
+                    try:
+                        j = json.load(f)
+                        for k in list(j):
+                            j[k.lower()] = j.pop(k)
+                    except Exception as e:
+                        log.warning(f"Failed to parse JSON at {self.json_path}: {e}")
+                        return
+                    for provider_name, provider_class in providers.items():
+                        provider_name = provider_name.lower()
+                        provider_json = j.get(provider_name, {})
+                        self.providers[provider_name] = provider_class(provider_json)
+                    self.cache_key = self.json_path.stat()
+            else:
                 for provider_name, provider_class in providers.items():
                     provider_name = provider_name.lower()
-                    provider_json = j.get(provider_name, {})
-                    self.providers[provider_name] = provider_class(
-                        provider_json, self.httpx_client
-                    )
-                self.cache_key = self.json_path.stat()
-        else:
-            for provider_name, provider_class in providers.items():
-                provider_name = provider_name.lower()
-                self.providers[provider_name] = provider_class(None, self.httpx_client)
+                    self.providers[provider_name] = provider_class(None)
+            self.write_cache()
+
+    def load_from_cache(self):
+        with open(self.cache_path, "rb") as f:
+            try:
+                self.providers = pickle.load(f)
+            except Exception as e:
+                with suppress(Exception):
+                    self.cache_path.unlink()
+                log.warning(
+                    f"Failed to load cloudcheck cache at {self.cache_path}: {e}"
+                )
+
+    def write_cache(self):
+        with open(self.cache_path, "wb") as f:
+            try:
+                pickle.dump(self.providers, f)
+            except Exception as e:
+                log.warning(
+                    f"Failed to write cloudcheck cache to {self.cache_path}: {e}"
+                )
 
     def check(self, host):
         host = make_ip_type(host)
@@ -79,13 +103,14 @@ async def update(self, days=1, force=False):
         if self.last_updated > oldest_allowed and not force:
             return
         try:
-            response = await self.httpx_client.get(self.json_url)
+            async with httpx.AsyncClient() as client:
+                response = await client.get(self.json_url)
         except Exception as e:
             error = e
         if response is not None and response.status_code == 200 and response.content:
             with open(self.json_path, "wb") as f:
                 f.write(response.content)
-            self.load_from_json()
+            self.load_from_json(force=True)
         else:
             log.warning(
                 f"Failed to retrieve update from {self.json_url} (response: {response}, error: {error})"
@@ -100,7 +125,7 @@ async def update_from_sources(self):
             json.dump(
                 self.to_json(), f, sort_keys=True, indent=4, cls=CustomJSONEncoder
             )
-        self.load_from_json()
+        self.load_from_json(force=True)
 
     def to_json(self):
         d = {}
@@ -115,12 +140,6 @@ def last_updated(self):
         else:
             return datetime.min
 
-    @property
-    def httpx_client(self):
-        if self._httpx_client is None:
-            self._httpx_client = httpx.AsyncClient(verify=False)
-        return self._httpx_client
-
     def __iter__(self):
         yield from self.providers.values()
 
diff --git a/cloudcheck/providers/base.py b/cloudcheck/providers/base.py
index 4485225..2def36e 100644
--- a/cloudcheck/providers/base.py
+++ b/cloudcheck/providers/base.py
@@ -16,6 +16,11 @@
 asndb = None
 
 
+base_headers = {
+    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
+}
+
+
 class CloudProviderJSON(BaseModel):
     name: str = ""
     domains: List[str] = []
@@ -39,13 +44,9 @@ class BaseCloudProvider:
     regexes = {}
     provider_type = "cloud"
     ips_url = ""
-    headers = {
-        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.79 Safari/537.36"
-    }
     asns = []
 
-    def __init__(self, j, httpx_client=None):
-        self._httpx_client = httpx_client
+    def __init__(self, j):
         self._log = None
         self.ranges = set()
         self.radix = RadixTarget()
@@ -81,8 +82,11 @@ async def update(self):
             self.last_updated = datetime.now()
             self.ranges = self.get_subnets()
             if self.ips_url:
-                response = await self.httpx_client.get(
-                    self.ips_url, follow_redirects=True, headers=self.headers
-                )
+                async with httpx.AsyncClient(verify=False) as client:
+                    response = await client.get(
+                        self.ips_url,
+                        follow_redirects=True,
+                        headers=base_headers,
+                    )
                 ranges = self.parse_response(response)
                 if ranges:
@@ -133,12 +137,6 @@ def to_json(self):
             bucket_name_regex=self.bucket_name_regex,
         ).model_dump()
 
-    @property
-    def httpx_client(self):
-        if self._httpx_client is None:
-            self._httpx_client = httpx.AsyncClient(verify=False)
-        return self._httpx_client
-
     def parse_response(self, response):
         pass
 