From c2747d1e87ba9d26e8059f827ee0b035ed445193 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Mon, 1 Apr 2024 21:03:53 +0000 Subject: [PATCH] Improve: SQLite download process --- python/scripts/test_sqlite.py | 5 +-- python/usearch/__init__.py | 57 ++++++++++++++++++++--------------- sqlite/README.md | 2 +- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/python/scripts/test_sqlite.py b/python/scripts/test_sqlite.py index 8ecb5dcc..baf0f7c2 100644 --- a/python/scripts/test_sqlite.py +++ b/python/scripts/test_sqlite.py @@ -9,8 +9,9 @@ import usearch -found_sqlite_path = usearch.sqlite -if found_sqlite_path is None: +try: + found_sqlite_path = usearch.sqlite_path() +except FileNotFoundError: pytest.skip(reason="Can't find an SQLite installation", allow_module_level=True) diff --git a/python/usearch/__init__.py b/python/usearch/__init__.py index f3006410..bb1bf13d 100644 --- a/python/usearch/__init__.py +++ b/python/usearch/__init__.py @@ -20,20 +20,23 @@ class BinaryManager: def __init__(self, version: Optional[str] = None): if version is None: version = __version__ - self.version = version - - def sqlite_download_url(self) -> str: - """ - Constructs a download URL for the `usearch_sqlite` binary based on the operating system, architecture, and version. - - Args: - version (str): The version of the binary to download. + self.version = version or __version__ + self.download_dir = self.determine_download_dir() + + @staticmethod + def determine_download_dir(): + # Check if running within a virtual environment + virtual_env = os.getenv("VIRTUAL_ENV") + if virtual_env: + # Use a subdirectory within the virtual environment for binaries + return os.path.join(virtual_env, "bin", "usearch_binaries") + else: + # Fallback to a directory in the user's home folder + home_dir = os.path.expanduser("~") + return os.path.join(home_dir, ".usearch", "binaries") - Returns: - A string representing the download URL. - """ + def sqlite_file_name(self) -> str: version = self.version - base_url = "https://github.com/unum-cloud/usearch/releases/download" os_map = {"Linux": "linux", "Windows": "windows", "Darwin": "macos"} arch_map = { "x86_64": "amd64" if platform.system() != "Darwin" else "x86_64", @@ -47,6 +50,12 @@ def sqlite_download_url(self) -> str: arch_part = arch_map.get(arch, "") extension = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}.get(platform.system(), "") filename = f"usearch_sqlite_{os_part}_{arch_part}_{version}.{extension}" + return filename + + def sqlite_download_url(self) -> str: + version = self.version + filename = self.sqlite_file_name() + base_url = "https://github.com/unum-cloud/usearch/releases/download" url = f"{base_url}/v{version}/{filename}" return url @@ -66,7 +75,6 @@ def download_binary(self, url: str, dest_folder: str) -> str: urllib.request.urlretrieve(url, dest_path) return dest_path - @property def sqlite_found_or_downloaded(self) -> Optional[str]: """ Attempts to locate the pre-installed `usearch_sqlite` binary. @@ -89,20 +97,16 @@ def sqlite_found_or_downloaded(self) -> Optional[str]: return os.path.join(root, file).removesuffix(file_extension) # Check a temporary directory (assuming the binary might be downloaded from a GitHub release) - temp_dir = tempfile.gettempdir() - for root, _, files in os.walk(temp_dir): - for file in files: - if file.endswith(file_extension) and "usearch_sqlite" in file: - return os.path.join(root, file).removesuffix(file_extension) + local_path = os.path.join(self.download_dir, self.sqlite_file_name()) + if os.path.exists(local_path): + return local_path.removesuffix(file_extension) # If not found locally, warn the user and download from GitHub - temp_dir = tempfile.gettempdir() warnings.warn("Will download `usearch_sqlite` binary from GitHub.", UserWarning) - - # If the download fails due to HTTPError (e.g., 404 Not Found), like a missing lib version try: - binary_path = self.download_binary(self.sqlite_download_url(), temp_dir) + binary_path = self.download_binary(self.sqlite_download_url(), self.download_dir) except HTTPError as e: + # If the download fails due to HTTPError (e.g., 404 Not Found), like a missing lib version if e.code == 404: warnings.warn(f"Download failed: {e.url} could not be found.", UserWarning) else: @@ -117,6 +121,9 @@ def sqlite_found_or_downloaded(self) -> Optional[str]: return None -# Use the function to set the `sqlite` computed property -binary_manager = BinaryManager() -sqlite = binary_manager.sqlite_found_or_downloaded +def sqlite_path(version: str = None) -> str: + manager = BinaryManager(version=version) + result = manager.sqlite_found_or_downloaded() + if result is None: + raise FileNotFoundError("Failed to find or download `usearch_sqlite` binary.") + return result diff --git a/sqlite/README.md b/sqlite/README.md index a74d105c..21b313e8 100644 --- a/sqlite/README.md +++ b/sqlite/README.md @@ -30,7 +30,7 @@ import usearch conn = sqlite3.connect(":memory:") conn.enable_load_extension(True) -conn.load_extension(usearch.sqlite) +conn.load_extension(usearch.sqlite_path()) ``` Afterwards, the following script should work fine.