Commit

iterate on script
gschoeni committed Oct 17, 2023
1 parent 876b62b commit c4d58e4
Showing 1 changed file with 181 additions and 111 deletions.
292 changes: 181 additions & 111 deletions scripts/hf2oxen.py
@@ -5,6 +5,7 @@
from datasets import load_dataset
from oxen.remote_repo import create_repo, get_repo
from oxen import LocalRepo
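# urllib.request is used below to fetch the dataset's README.md directly from the Hugging Face hub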
import urllib.request

def human_size(bytes, units=[' bytes','KB','MB','GB','TB', 'PB', 'EB']):
""" Returns a human readable string representation of bytes """
@@ -21,122 +22,191 @@ def query():
data = query()
return data

# argparse the name of the dataset
parser = argparse.ArgumentParser(description='Download a dataset from hugging face and upload to Oxen.')
# parse dataset as -d or --dataset
parser.add_argument('-d','--dataset', dest="dataset", required=True, help="Name of the dataset to download from hugging face")
parser.add_argument('-o','--output', dest="output", required=True, help="The output directory to save the dataset to")
parser.add_argument('-n', '--namespace', dest="namespace", default="ox", help="The oxen namespace to upload to")
parser.add_argument('--host', dest="host", default="hub.oxen.ai", help="The host to upload to")
args = parser.parse_args()

dataset_name = args.dataset
output_dir = args.output
namespace = args.namespace
host = args.host

api = HfApi()

info = api.repo_info(dataset_name, repo_type="dataset")
print(info)
print(info.description)
commits = api.list_repo_commits(dataset_name, repo_type="dataset")
commits.reverse()
print(f"Got {len(commits)} commits")

info = get_dataset_info(dataset_name)
print(info)
sizes = []
for key in info['dataset_info'].keys():
info_obj = info['dataset_info'][key]
if 'size_in_bytes' in info_obj:
size_in_bytes = info_obj['size_in_bytes']
else:
size_in_bytes = info_obj['dataset_size']
print(f"{key}: {human_size(size_in_bytes)}")
sizes.append(size_in_bytes)
sum_sizes = sum(sizes)
print(f"Dataset size: {human_size(sum_sizes)}")

if sum_sizes > 5_000_000_000:
print(f"Dataset size is {human_size(sum_sizes)}, this is greater than 5GB, do not continue")
exit(1)

# if dir exists, do not continue
output_dir = os.path.join(output_dir, dataset_name)
if os.path.exists(output_dir):
print(f"Directory {output_dir} exists, do not continue")
exit(1)

clean_name = dataset_name
if "/" in clean_name:
clean_name = dataset_name.replace("/", "_")

name = f"{namespace}/{clean_name}"
# Create Remote Repo
if get_repo(name, host=host):
print(f"Repo {name} exists, do not continue")
exit(1)

# create dir
os.makedirs(output_dir)

# TODO: Create repo with description and README.md based on the contents of the dataset info
remote_repo = create_repo(name, host=host)
local_repo = LocalRepo(output_dir)
local_repo.init()
local_repo.set_remote("origin", remote_repo.url())

for commit in commits:
print(f"Loading commit: {commit}...")
def get_repo_info(dataset_name):
api = HfApi()

info = api.repo_info(dataset_name, repo_type="dataset")
print(info)
print(info.description)

info = get_dataset_info(dataset_name)
print(info)

print("\n\n")
print("="*80)

# download the dataset from hugging face
try:
hf_dataset = load_dataset(dataset_name, revision=commit.commit_id)
print(hf_dataset)
sizes = [0]
description = ""
if 'dataset_info' in info:
subsets = info['dataset_info'].keys()
for key in subsets:
info_obj = info['dataset_info'][key]
if 'size_in_bytes' in info_obj:
size_in_bytes = info_obj['size_in_bytes']
else:
size_in_bytes = info_obj['dataset_size']
print(f"\n====\n{key}: {human_size(size_in_bytes)}")
sizes.append(size_in_bytes)

subset_description = info_obj['description'].strip()
print(subset_description)
if description == "":
description = subset_description
else:
subsets = ["default"]

sum_sizes = sum(sizes)
print(f"Dataset Total Size: {human_size(sum_sizes)}")
print("="*80)
print("\n\n")

print(f"\n\nDescription:\n\n{description}\n\n")

return {"size": sum_sizes, "description": description, "subsets": subsets}

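# Downloads every requested subset (optionally pinned to a specific Hugging Face revision), writes each split to parquet under data_dir, then commits and pushes it to the Oxen repo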
def download_dataset_subsets(dataset_name, subsets, local_repo, data_dir, commit=None):
for subset in subsets:
branch_name = subset

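# A single-subset dataset is committed straight to main; otherwise each subset gets its own branch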
if len(subsets) == 1:
if commit:
print(f"\nCalling load_dataset('{dataset_name}', revision='{commit.commit_id}')...\n")
hf_dataset = load_dataset(dataset_name, revision=commit.commit_id)
else:
print(f"\nCalling load_dataset('{dataset_name}')...\n")
hf_dataset = load_dataset(dataset_name)
branch_name = "main"
else:
branch_names = [branch.name for branch in local_repo.branches()]
print(f"Branches: {branch_names}")
print(f"Checking out branch {branch_name}...")
if branch_name not in branch_names:
print(f"Creating branch {branch_name}...")
local_repo.checkout(branch_name, create=True)

if commit:
print(f"\nCalling load_dataset('{dataset_name}', '{subset}', revision='{commit.commit_id}')...\n")
hf_dataset = load_dataset(dataset_name, subset, revision=commit.commit_id)
else:
print(f"\nCalling load_dataset('{dataset_name}', '{subset}')...\n")
hf_dataset = load_dataset(dataset_name, subset)

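# Write each split (train/validation/test/...) to a parquet file and stage it in the Oxen repo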
for key, dataset in hf_dataset.items():
filename = os.path.join(output_dir, f"{key}.parquet")
filename = os.path.join(data_dir, f"{key}.parquet")
dataset.to_parquet(filename)
print(f"Adding {filename} to local repo")
local_repo.add(filename)


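# Only create an Oxen commit when the staged data actually changed; reuse the upstream commit message when one is available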
status = local_repo.status()
print(status)
if status.is_dirty():
print(f"✅ Committing {dataset_name} to {branch_name}...")

if commit:
commit_message = f"{commit.title}\n\n{commit.message}"
if commit.title == "" and commit.message == "":
commit_message = f"Update dataset from git commit {commit.commit_id}"
local_repo.commit(commit_message)
else:
local_repo.commit("Adding dataset")

print(f"Pushing {dataset_name} to {host}...")
local_repo.push(branch=branch_name)

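# Fetches the dataset's README.md from the Hugging Face hub and commits it to the repo; logs and moves on when the dataset doesn't have one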
def download_and_add_readme_if_exists(dataset_name, local_repo):
# Download the readme
try:
readme_url = f"https://huggingface.co/datasets/{dataset_name}/resolve/main/README.md"
readme_file = os.path.join(output_dir, "README.md")
print(f"Downloading {readme_url} to {readme_file}")
urllib.request.urlretrieve(readme_url, readme_file)

local_repo.add(readme_file)
local_repo.commit("Adding README.md")
local_repo.push()
except Exception as e:
print(f"Failed to download README.md from dataset {dataset_name}")
print(f"Got Exception: {e}")
error_str = f"{e}"
split_str = "Please pick one among the available configs: ["
if split_str in error_str:
config_options = error_str.split(split_str)[-1]
config_options = config_options.split("]")[0]
print(f"Available configs for {dataset_name}: {config_options}")
options = config_options.split(",")
for option in options:
option = option.replace("'", "").strip()
print(f"Download dataset {dataset_name} with option {option}")
hf_dataset = load_dataset(dataset_name, option, revision=commit.commit_id)
print(hf_dataset)

# info = hf_dataset.info
# print(info)

for key, dataset in hf_dataset.items():
filename = os.path.join(output_dir, f"{key}_{option}.parquet")
dataset.to_parquet(filename)
local_repo.add(filename)
except:
print(f"Failed to download dataset {dataset_name} with commit {commit}")
continue

status = local_repo.status()
commit_message = f"{commit.title}\n\n{commit.message}"
if status.is_dirty():
print(f"✅ Committing with message: {commit_message}...")

if commit_message == "":
commit.message = f"Update dataset {commit.commit_id}"

local_repo.commit(commit_message)
else:
print(f"🤷‍♂️ Skipping commit with message: {commit_message}...")

print(f"Uploading {dataset_name} to {host}...")
local_repo.push()
# entry point: parse args, create the Oxen repo, and mirror the dataset
if __name__ == "__main__":

# argparse the name of the dataset
parser = argparse.ArgumentParser(description='Download a dataset from Hugging Face and upload it to Oxen.')
# parse the dataset as -d or --dataset
parser.add_argument('-d', '--dataset', dest="dataset", required=True, help="Name of the dataset to download from Hugging Face")
parser.add_argument('-o', '--output', dest="output", required=True, help="The output directory to save the dataset to")
parser.add_argument('-n', '--namespace', dest="namespace", default="ox", help="The Oxen namespace to upload to")
parser.add_argument('--host', dest="host", default="hub.oxen.ai", help="The host to upload to")
args = parser.parse_args()
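# Example invocation (the dataset name here is just an illustration):
#   python scripts/hf2oxen.py --dataset squad --output datasets/ --namespace ox --host hub.oxen.ai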

dataset_name = args.dataset
output_dir = args.output
namespace = args.namespace
host = args.host


# if dir exists, do not continue
output_dir = os.path.join(output_dir, dataset_name)
if os.path.exists(output_dir):
print(f"Directory {output_dir} exists, do not continue")
exit(1)

clean_name = dataset_name
if "/" in clean_name:
clean_name = dataset_name.replace("/", "_")

name = f"{namespace}/{clean_name}"
# Create Remote Repo
if get_repo(name, host=host):
print(f"Repo {name} exists, do not continue")
exit(1)

# create dirs
data_dir = os.path.join(output_dir, "data")
os.makedirs(data_dir)
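# Parquet files are written to <output>/<dataset>/data, while README.md sits at the repo root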

# {"size": sum_sizes, "description": description, "subsets": subsets}
info = get_repo_info(dataset_name)
sum_sizes = info['size']
description = info['description']
subsets = info['subsets']

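# Keep uploads manageable: bail out if the dataset is larger than 5GB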
if sum_sizes > 5_000_000_000:
print(f"Dataset size is {human_size(sum_sizes)}, this is greater than 5GB, do not continue")
exit(1)

# Create Oxen Remote Repo
remote_repo = create_repo(name, description=description, host=host)
local_repo = LocalRepo(output_dir)
local_repo.init()
local_repo.set_remote("origin", remote_repo.url())

# Try to add the README.md; not every dataset has one
download_and_add_readme_if_exists(dataset_name, local_repo)

# Try to process the commit history
api = HfApi()
commits = api.list_repo_commits(dataset_name, repo_type="dataset")
commits.reverse()
print(f"\nProcessing {len(commits)} commits\n")
for commit in commits:
print(f"Loading commit: {commit}...")

# download the dataset at this specific Hugging Face revision
try:
download_dataset_subsets(dataset_name, subsets, local_repo, data_dir, commit=commit)

except Exception as e:
print(f"Failed to download commit {commit} from dataset {dataset_name}")
print(f"Got Exception: {e}")


# Download the latest version with a plain load_dataset() call in case replaying the commit history failed entirely, since some datasets have a broken commit history
local_repo.checkout("main")
subsets = ["default"]
if not os.path.exists(data_dir):
os.makedirs(data_dir)
download_dataset_subsets(dataset_name, subsets, local_repo, data_dir)

