
Merge pull request #58 from blacklanternsecurity/cli-tests
Caching functionality
TheTechromancer committed Jun 20, 2024
2 parents c4c2f61 + facf176 commit 6a135a3
Showing 2 changed files with 59 additions and 43 deletions.
cloudcheck/providers/__init__.py (78 changes: 48 additions & 30 deletions)
@@ -1,9 +1,11 @@
 import json
 import httpx
+import pickle
 import asyncio
 import logging
 import importlib
 from pathlib import Path
+from contextlib import suppress
 from datetime import datetime, timedelta
 
 from .base import BaseCloudProvider
@@ -35,33 +37,55 @@ class CloudProviders:
     json_path = code_path / "cloud_providers.json"
     cache_path = code_path / ".cloudcheck_cache"
 
-    def __init__(self, httpx_client=None):
+    def __init__(self):
         self.providers = {}
-        self._httpx_client = httpx_client
         self.cache_key = None
         self.load_from_json()
 
-    def load_from_json(self):
-        if self.json_path.is_file():
-            with open(self.json_path) as f:
-                try:
-                    j = json.load(f)
-                    for k in list(j):
-                        j[k.lower()] = j.pop(k)
-                except Exception as e:
-                    log.warning(f"Failed to parsed JSON at {self.json_path}: {e}")
-                    return
-            for provider_name, provider_class in providers.items():
-                provider_name = provider_name.lower()
-                provider_json = j.get(provider_name, {})
-                self.providers[provider_name] = provider_class(
-                    provider_json, self.httpx_client
-                )
-            self.cache_key = self.json_path.stat()
-        else:
-            for provider_name, provider_class in providers.items():
-                provider_name = provider_name.lower()
-                self.providers[provider_name] = provider_class(None, self.httpx_client)
+    def load_from_json(self, force=False):
+        # loading from a pickled cache is about 1 second faster than loading from JSON
+        if (not force) and self.cache_path.is_file():
+            self.load_from_cache()
+        else:
+            if self.json_path.is_file():
+                with open(self.json_path) as f:
+                    try:
+                        j = json.load(f)
+                        for k in list(j):
+                            j[k.lower()] = j.pop(k)
+                    except Exception as e:
+                        log.warning(f"Failed to parse JSON at {self.json_path}: {e}")
+                        return
+                for provider_name, provider_class in providers.items():
+                    provider_name = provider_name.lower()
+                    provider_json = j.get(provider_name, {})
+                    self.providers[provider_name] = provider_class(provider_json)
+                self.cache_key = self.json_path.stat()
+            else:
+                for provider_name, provider_class in providers.items():
+                    provider_name = provider_name.lower()
+                    self.providers[provider_name] = provider_class(None)
+            self.write_cache()
+
+    def load_from_cache(self):
+        with open(self.cache_path, "rb") as f:
+            try:
+                self.providers = pickle.load(f)
+            except Exception as e:
+                with suppress(Exception):
+                    self.cache_path.unlink()
+                log.warning(
+                    f"Failed to load cloudcheck cache at {self.cache_path}: {e}"
+                )
+
+    def write_cache(self):
+        with open(self.cache_path, "wb") as f:
+            try:
+                pickle.dump(self.providers, f)
+            except Exception as e:
+                log.warning(
+                    f"Failed to write cloudcheck cache to {self.cache_path}: {e}"
+                )
 
     def check(self, host):
         host = make_ip_type(host)
@@ -79,13 +103,13 @@ async def update(self, days=1, force=False):
         if self.last_updated > oldest_allowed and not force:
             return
         try:
-            response = await self.httpx_client.get(self.json_url)
+            response = await httpx.get(self.json_url)
         except Exception as e:
            error = e
         if response is not None and response.status_code == 200 and response.content:
            with open(self.json_path, "wb") as f:
                f.write(response.content)
-            self.load_from_json()
+            self.load_from_json(force=True)
         else:
             log.warning(
                 f"Failed to retrieve update from {self.json_url} (response: {response}, error: {error})"
@@ -100,7 +124,7 @@ async def update_from_sources(self):
             json.dump(
                 self.to_json(), f, sort_keys=True, indent=4, cls=CustomJSONEncoder
             )
-        self.load_from_json()
+        self.load_from_json(force=True)
 
     def to_json(self):
         d = {}
@@ -115,12 +139,6 @@ def last_updated(self):
         else:
             return datetime.min
 
-    @property
-    def httpx_client(self):
-        if self._httpx_client is None:
-            self._httpx_client = httpx.AsyncClient(verify=False)
-        return self._httpx_client
-
     def __iter__(self):
         yield from self.providers.values()

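The net effect in __init__.py: CloudProviders no longer takes an httpx_client; on startup it prefers the pickled .cloudcheck_cache and only re-parses cloud_providers.json (rewriting the cache afterwards) when the cache is absent or a reload is forced. A rough usage sketch of the new flow (the import path matches the file shown above; the example IP and printed result are illustrative only):

import asyncio
from cloudcheck.providers import CloudProviders

async def main():
    # __init__ no longer accepts an httpx client; it calls load_from_json(),
    # which prefers the pickled .cloudcheck_cache when one exists
    providers = CloudProviders()

    # re-download cloud_providers.json; on success update() calls
    # load_from_json(force=True), which re-parses the JSON and rewrites the cache
    await providers.update(force=True)

    # look up an arbitrary address against the loaded provider ranges
    print(providers.check("8.8.8.8"))

asyncio.run(main())
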
cloudcheck/providers/base.py (24 changes: 11 additions & 13 deletions)
@@ -16,6 +16,11 @@
 asndb = None
 
 
+base_headers = {
+    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
+}
+
+
 class CloudProviderJSON(BaseModel):
     name: str = ""
     domains: List[str] = []
@@ -39,13 +44,9 @@ class BaseCloudProvider:
     regexes = {}
     provider_type = "cloud"
     ips_url = ""
-    headers = {
-        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/104.0.5112.79 Safari/537.36"
-    }
     asns = []
 
-    def __init__(self, j, httpx_client=None):
-        self._httpx_client = httpx_client
+    def __init__(self, j):
         self._log = None
         self.ranges = set()
         self.radix = RadixTarget()
@@ -81,8 +82,11 @@ async def update(self):
         self.last_updated = datetime.now()
         self.ranges = self.get_subnets()
         if self.ips_url:
-            response = await self.httpx_client.get(
-                self.ips_url, follow_redirects=True, headers=self.headers
+            response = await httpx.get(
+                self.ips_url,
+                follow_redirects=True,
+                headers=base_headers,
+                verify=False,
             )
             ranges = self.parse_response(response)
             if ranges:
@@ -133,12 +137,6 @@ def to_json(self):
             bucket_name_regex=self.bucket_name_regex,
         ).model_dump()
 
-    @property
-    def httpx_client(self):
-        if self._httpx_client is None:
-            self._httpx_client = httpx.AsyncClient(verify=False)
-        return self._httpx_client
-
     def parse_response(self, response):
         pass
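
In base.py, the per-class headers dict and the httpx_client property are replaced by a module-level base_headers and direct httpx calls, while parse_response stays the hook each provider overrides: update() fetches ips_url and passes the response in, expecting the advertised ranges back. A hypothetical subclass (the name, URL, and JSON shape are invented for illustration, not a real cloudcheck provider):

from cloudcheck.providers.base import BaseCloudProvider

class ExampleCloud(BaseCloudProvider):
    # hypothetical provider; the URL and response format are invented
    ips_url = "https://example.com/ip-ranges.json"

    def parse_response(self, response):
        # update() hands in the httpx response; return the CIDR strings it publishes
        return {entry["prefix"] for entry in response.json().get("prefixes", [])}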
