diff --git a/simple_shapes_dataset/cli/download.py b/simple_shapes_dataset/cli/download.py index cb82123..2123ff4 100644 --- a/simple_shapes_dataset/cli/download.py +++ b/simple_shapes_dataset/cli/download.py @@ -47,11 +47,30 @@ def downlad_file(url: str, path: Path): "Useful if you need the old version of the dataset." ), ) -def download_dataset(path: Path, force: bool, no_migration: bool): +@click.option( + "--ckpturl", + default=DATASET_URL, + help=( + "Dataset URL. Defaults to " + "`https://zenodo.org/records/8112838/files/simple_shapes_dataset.tar.gz`" + ), +) +@click.option( + "--name", + default="simple_shapes_dataset", + help="name of the folder to download. Defaults to `simple_shapes_dataset`", +) +def download_dataset( + path: Path, + force: bool, + no_migration: bool, + ckpturl: str = DATASET_URL, + name: str = "simple_shapes_dataset", +): click.echo(f"Downloading in {str(path)}.") - dataset_path = path / "simple_shapes_dataset" - archive_path = path / "simple_shapes_dataset.tar.gz" + dataset_path = path / name + archive_path = path / f"{name}.tar.gz" if dataset_path.exists() and not force: click.echo( "Dataset already exists. Skipping download. " @@ -61,7 +80,7 @@ def download_dataset(path: Path, force: bool, no_migration: bool): elif dataset_path.exists(): click.echo("Dataset already exists. Re-downloading.") shutil.rmtree(dataset_path) - downlad_file(DATASET_URL, archive_path) + downlad_file(ckpturl, archive_path) click.echo("Extracting archive...") with tarfile.open(archive_path, "r:gz") as archive: archive.extractall(path)