diff --git a/benchmark_qed/data/cli.py b/benchmark_qed/data/cli.py index 64e2083..fdc835b 100644 --- a/benchmark_qed/data/cli.py +++ b/benchmark_qed/data/cli.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Microsoft Corporation. """Data downloader CLI.""" +import zipfile from enum import StrEnum from pathlib import Path from typing import Annotated @@ -51,3 +52,9 @@ def download( response = requests.get(api_url, timeout=60) output_file = output_dir / f"{dataset}.zip" output_file.write_bytes(response.content) + + with zipfile.ZipFile(output_file, "r") as zip_ref: + zip_ref.extractall(output_dir) + + output_file.unlink() # Remove the zip file after extraction + typer.echo(f"Dataset {dataset} downloaded and extracted to {output_dir}.")